Instructions to use dn6/RosettaFold-3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use dn6/RosettaFold-3 with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("dn6/RosettaFold-3", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files- README.md +140 -0
- __init__.py +19 -0
- before_denoise.py +223 -0
- decoders.py +169 -0
- denoise.py +135 -0
- modular_blocks.py +91 -0
- modular_config.json +7 -0
- modular_model_index.json +34 -0
- scheduler/__init__.py +4 -0
- scheduler/config.json +16 -0
- scheduler/model.py +90 -0
- transformer/__init__.py +4 -0
- transformer/config.json +25 -0
- transformer/diffusion_pytorch_model.safetensors +3 -0
- transformer/model.py +286 -0
README.md
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protein Structure Prediction with Diffusers
|
| 2 |
+
|
| 3 |
+
A [diffusers](https://github.com/huggingface/diffusers) `ModularPipeline` wrapper for [RosettaFold3](https://doi.org/10.1101/2025.08.14.670328) (RF3) β a diffusion-based protein structure prediction model that predicts 3D atomic coordinates from amino acid sequences.
|
| 4 |
+
|
| 5 |
+
RF3 relies on [Foundry](https://github.com/RosettaCommons/foundry) for its underlying implementation and [AtomWorks](https://github.com/RosettaCommons/atomworks) for structure I/O. This package adds only the thin wrappers needed for diffusers integration.
|
| 6 |
+
|
| 7 |
+
## Getting Started
|
| 8 |
+
|
| 9 |
+
### Installation
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
pip install rc-foundry[all]
|
| 13 |
+
pip install diffusers
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
### Running with Diffusers
|
| 17 |
+
|
| 18 |
+
```python
|
| 19 |
+
import torch
|
| 20 |
+
from diffusers import ModularPipeline
|
| 21 |
+
|
| 22 |
+
pipe = ModularPipeline.from_pretrained("dn6/RosettaFold-3", trust_remote_code=True)
|
| 23 |
+
pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
|
| 24 |
+
|
| 25 |
+
state = pipe(sequence="MKVLSEGDPWRK...")
|
| 26 |
+
print(state.output.xyz.shape) # [D, L, 3]
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Workflows
|
| 30 |
+
|
| 31 |
+
| Workflow | Trigger inputs | What runs |
|
| 32 |
+
|----------|---------------|-----------|
|
| 33 |
+
| `fold` | `sequence` | Full structure prediction (recycling trunk + diffusion) |
|
| 34 |
+
|
| 35 |
+
### Fold a Sequence
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
state = pipe(sequence="MKVLSEGDPWRK...", output_type="cif.gz", output_path="prediction")
|
| 39 |
+
print(state.output.atom_array)
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
### Full Design Pipeline
|
| 43 |
+
|
| 44 |
+
RF3 is typically used as a validation step after backbone design with [RFdiffusion3](https://huggingface.co/dn6/RFDiffusion-3):
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
RFD3 (design backbone) β MPNN (design sequence) β RF3 (validate fold)
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
```python
|
| 51 |
+
import torch
|
| 52 |
+
from diffusers import AutoModel, ModularPipeline
|
| 53 |
+
|
| 54 |
+
# 1. Design a backbone + sequence
|
| 55 |
+
design_pipe = ModularPipeline.from_pretrained("dn6/RFDiffusion-3", trust_remote_code=True)
|
| 56 |
+
design_pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
|
| 57 |
+
|
| 58 |
+
mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn", trust_remote_code=True)
|
| 59 |
+
design_pipe.update_components(mpnn=mpnn)
|
| 60 |
+
|
| 61 |
+
state = design_pipe(contigs="100", temperature=0.1)
|
| 62 |
+
designed_sequence = state.mpnn_output.designed_sequence
|
| 63 |
+
|
| 64 |
+
# 2. Validate the fold
|
| 65 |
+
fold_pipe = ModularPipeline.from_pretrained("dn6/RosettaFold-3", trust_remote_code=True)
|
| 66 |
+
fold_pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
|
| 67 |
+
|
| 68 |
+
state = fold_pipe(sequence=designed_sequence, output_type="cif.gz", output_path="prediction")
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
## Customizing Workflows
|
| 72 |
+
|
| 73 |
+
```python
|
| 74 |
+
# Inspect the pipeline structure
|
| 75 |
+
print(pipe.blocks)
|
| 76 |
+
|
| 77 |
+
# Add a custom block
|
| 78 |
+
from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
|
| 79 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
|
| 80 |
+
|
| 81 |
+
class ComputeRadiusOfGyration(ModularPipelineBlocks):
|
| 82 |
+
@property
|
| 83 |
+
def inputs(self):
|
| 84 |
+
return [InputParam("xyz", required=True)]
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def intermediate_outputs(self):
|
| 88 |
+
return [OutputParam("radius_of_gyration")]
|
| 89 |
+
|
| 90 |
+
def __call__(self, components, state):
|
| 91 |
+
block_state = self.get_block_state(state)
|
| 92 |
+
xyz = block_state.xyz
|
| 93 |
+
centroid = xyz.mean(dim=-2, keepdim=True)
|
| 94 |
+
block_state.radius_of_gyration = ((xyz - centroid) ** 2).sum(-1).mean().sqrt()
|
| 95 |
+
self.set_block_state(state, block_state)
|
| 96 |
+
return components, state
|
| 97 |
+
|
| 98 |
+
pipe._blocks.sub_blocks.insert("rog", ComputeRadiusOfGyration(), index=3)
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Output Types
|
| 102 |
+
|
| 103 |
+
| `output_type` | Additional output | Writes to disk |
|
| 104 |
+
|---|---|---|
|
| 105 |
+
| `"tensor"` | β | β |
|
| 106 |
+
| `"pdb"` | `pdb_string` | `.pdb` file |
|
| 107 |
+
| `"cif"` | `atom_array`, `atom_array_stack`, `trajectory_stack` | `.cif` via AtomWorks |
|
| 108 |
+
| `"cif.gz"` | Same as `"cif"` | `.cif.gz` compressed |
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
# CIF output with AtomArray
|
| 112 |
+
state = pipe(sequence="MKVLSEG...", output_type="cif.gz", output_path="fold_0")
|
| 113 |
+
atom_array = state.output.atom_array
|
| 114 |
+
|
| 115 |
+
# Denoising trajectory
|
| 116 |
+
trajectory = state.output.trajectory_stack
|
| 117 |
+
|
| 118 |
+
# PDB output
|
| 119 |
+
state = pipe(sequence="MKVLSEG...", output_type="pdb", output_path="fold_0.pdb")
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Model Architecture
|
| 123 |
+
|
| 124 |
+
RF3 is a diffusion model with the same EDM noise schedule as RFdiffusion3 (200 steps), but conditioned on sequence/MSA/template representations from a large recycling trunk:
|
| 125 |
+
|
| 126 |
+
| Component | Subfolder | Description |
|
| 127 |
+
|-----------|-----------|-------------|
|
| 128 |
+
| `transformer` | `transformer/` | `RF3TransformerModel` (366M params) β FeatureInitializer + Recycler (48 pairformer blocks) + DiffusionModule (24 transformer blocks) + DistogramHead |
|
| 129 |
+
| `scheduler` | `scheduler/` | `RF3Scheduler` (EDM schedule, gamma_0=0.8) |
|
| 130 |
+
|
| 131 |
+
## Citation
|
| 132 |
+
|
| 133 |
+
```bibtex
|
| 134 |
+
@article{corley2025accelerating,
|
| 135 |
+
author = {Corley, Nathaniel and Mathis, Simon and Krishna, Rohith and Bauer, Magnus S and Thompson, Tuscan R and Ahern, Woody and Kazman, Maxwell W and Brent, Rafael I and Didi, Kieran and Kubaney, Andrew and others},
|
| 136 |
+
title = {Accelerating biomolecular modeling with AtomWorks and RF3},
|
| 137 |
+
journal = {bioRxiv},
|
| 138 |
+
year = {2025},
|
| 139 |
+
}
|
| 140 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
from .transformer import RF3TransformerModel, RF3TransformerOutput
|
| 5 |
+
from .scheduler import RF3Scheduler
|
| 6 |
+
from .modular_blocks import (
|
| 7 |
+
RF3AutoBeforeDenoiseStep,
|
| 8 |
+
RF3AutoBlocks,
|
| 9 |
+
RF3AutoDecodeStep,
|
| 10 |
+
RF3AutoDenoiseStep,
|
| 11 |
+
)
|
| 12 |
+
from .before_denoise import (
|
| 13 |
+
RF3InputStep,
|
| 14 |
+
RF3PrepareLatentsStep,
|
| 15 |
+
RF3RecyclingStep,
|
| 16 |
+
RF3SetTimestepsStep,
|
| 17 |
+
)
|
| 18 |
+
from .denoise import RF3DenoiseStep
|
| 19 |
+
from .decoders import RF3DecodeStep, RF3PipelineOutput
|
before_denoise.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Pre-denoising steps for RF3: input processing, timestep setup, recycling trunk, latent preparation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from diffusers.utils import logging
|
| 13 |
+
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
|
| 14 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RF3InputStep(ModularPipelineBlocks):
|
| 21 |
+
"""Parse sequence input and prepare feature dict for RF3."""
|
| 22 |
+
|
| 23 |
+
model_name = "rf3"
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def description(self) -> str:
|
| 27 |
+
return "Parse sequence and optional MSA/template inputs for structure prediction."
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
def inputs(self) -> List[InputParam]:
|
| 31 |
+
return [
|
| 32 |
+
InputParam("sequence", required=True, type_hint=str, description="Amino acid sequence (one-letter codes)"),
|
| 33 |
+
InputParam("f", type_hint=dict, description="Pre-built feature dict (overrides sequence)"),
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
@property
|
| 37 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 38 |
+
return [
|
| 39 |
+
OutputParam("f", type_hint=dict, description="Feature dictionary for RF3"),
|
| 40 |
+
OutputParam("L", type_hint=int, description="Sequence length (num atoms)"),
|
| 41 |
+
OutputParam("I", type_hint=int, description="Num tokens"),
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
@torch.no_grad()
|
| 45 |
+
def __call__(self, components, state):
|
| 46 |
+
block_state = self.get_block_state(state)
|
| 47 |
+
|
| 48 |
+
f = block_state.f
|
| 49 |
+
sequence = block_state.sequence
|
| 50 |
+
|
| 51 |
+
if f is None:
|
| 52 |
+
# Build minimal feature dict from sequence
|
| 53 |
+
L = len(sequence)
|
| 54 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 55 |
+
|
| 56 |
+
# Map sequence to restype indices
|
| 57 |
+
AA_ORDER = "ARNDCQEGHILKMFPSTWYV"
|
| 58 |
+
restype = torch.zeros(L, 32, device=device)
|
| 59 |
+
for i, aa in enumerate(sequence):
|
| 60 |
+
idx = AA_ORDER.find(aa)
|
| 61 |
+
if idx >= 0:
|
| 62 |
+
restype[i, idx] = 1.0
|
| 63 |
+
else:
|
| 64 |
+
restype[i, 20] = 1.0 # unknown
|
| 65 |
+
|
| 66 |
+
f = {
|
| 67 |
+
"restype": restype,
|
| 68 |
+
"atom_to_token_map": torch.arange(L, device=device),
|
| 69 |
+
"is_ca": torch.ones(L, dtype=torch.bool, device=device),
|
| 70 |
+
"ref_pos": torch.zeros(L, 3, device=device),
|
| 71 |
+
"ref_charge": torch.zeros(L, device=device),
|
| 72 |
+
"ref_mask": torch.ones(L, device=device),
|
| 73 |
+
"ref_element": torch.zeros(L, 128, device=device),
|
| 74 |
+
"ref_atom_name_chars": torch.zeros(L, 4, 64, device=device),
|
| 75 |
+
}
|
| 76 |
+
else:
|
| 77 |
+
L = f.get("ref_element", f.get("restype")).shape[0]
|
| 78 |
+
|
| 79 |
+
block_state.f = f
|
| 80 |
+
block_state.L = L
|
| 81 |
+
block_state.I = L # token count = atom count for CA-only
|
| 82 |
+
|
| 83 |
+
self.set_block_state(state, block_state)
|
| 84 |
+
return components, state
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class RF3SetTimestepsStep(ModularPipelineBlocks):
|
| 88 |
+
"""Set up EDM noise schedule for RF3."""
|
| 89 |
+
|
| 90 |
+
model_name = "rf3"
|
| 91 |
+
|
| 92 |
+
@property
|
| 93 |
+
def description(self) -> str:
|
| 94 |
+
return "Construct EDM noise schedule for RF3 diffusion sampling."
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 98 |
+
return [ComponentSpec("scheduler", description="RF3 EDM scheduler")]
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def inputs(self) -> List[InputParam]:
|
| 102 |
+
return [
|
| 103 |
+
InputParam("num_inference_steps", default=None, type_hint=int),
|
| 104 |
+
InputParam("L", required=True, type_hint=int),
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 109 |
+
return [
|
| 110 |
+
OutputParam("noise_schedule", type_hint=torch.Tensor),
|
| 111 |
+
OutputParam("num_inference_steps", type_hint=int),
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
@torch.no_grad()
|
| 115 |
+
def __call__(self, components, state):
|
| 116 |
+
block_state = self.get_block_state(state)
|
| 117 |
+
|
| 118 |
+
if hasattr(components, "scheduler") and components.scheduler is not None:
|
| 119 |
+
noise_schedule = components.scheduler.get_noise_schedule()
|
| 120 |
+
else:
|
| 121 |
+
noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, 200)
|
| 122 |
+
|
| 123 |
+
block_state.noise_schedule = noise_schedule
|
| 124 |
+
block_state.num_inference_steps = len(noise_schedule)
|
| 125 |
+
|
| 126 |
+
self.set_block_state(state, block_state)
|
| 127 |
+
return components, state
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class RF3RecyclingStep(ModularPipelineBlocks):
|
| 131 |
+
"""Run the recycling trunk (pairformer + MSA + templates)."""
|
| 132 |
+
|
| 133 |
+
model_name = "rf3"
|
| 134 |
+
|
| 135 |
+
@property
|
| 136 |
+
def description(self) -> str:
|
| 137 |
+
return "Run RF3 recycling trunk to produce single/pair representations."
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 141 |
+
return [ComponentSpec("transformer", description="RF3 transformer model")]
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def inputs(self) -> List[InputParam]:
|
| 145 |
+
return [
|
| 146 |
+
InputParam("f", required=True, type_hint=dict),
|
| 147 |
+
InputParam("n_recycles", default=None, type_hint=int),
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
@property
|
| 151 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 152 |
+
return [
|
| 153 |
+
OutputParam("single", type_hint=torch.Tensor, description="Single representation [I, c_s]"),
|
| 154 |
+
OutputParam("pair", type_hint=torch.Tensor, description="Pair representation [I, I, c_z]"),
|
| 155 |
+
OutputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
|
| 156 |
+
OutputParam("distogram", type_hint=torch.Tensor, description="Distogram prediction [I, I, bins]"),
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
@torch.no_grad()
|
| 160 |
+
def __call__(self, components, state):
|
| 161 |
+
block_state = self.get_block_state(state)
|
| 162 |
+
|
| 163 |
+
f = block_state.f
|
| 164 |
+
n_recycles = block_state.n_recycles
|
| 165 |
+
|
| 166 |
+
if hasattr(components, "transformer") and components.transformer is not None:
|
| 167 |
+
output = components.transformer(f=f, n_recycles=n_recycles)
|
| 168 |
+
block_state.single = output.single
|
| 169 |
+
block_state.pair = output.pair
|
| 170 |
+
block_state.distogram = output.distogram
|
| 171 |
+
block_state.s_inputs = None # populated inside forward
|
| 172 |
+
else:
|
| 173 |
+
# Placeholder when no model loaded
|
| 174 |
+
block_state.single = None
|
| 175 |
+
block_state.pair = None
|
| 176 |
+
block_state.distogram = None
|
| 177 |
+
block_state.s_inputs = None
|
| 178 |
+
|
| 179 |
+
self.set_block_state(state, block_state)
|
| 180 |
+
return components, state
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class RF3PrepareLatentsStep(ModularPipelineBlocks):
|
| 184 |
+
"""Prepare initial noised coordinates for diffusion sampling."""
|
| 185 |
+
|
| 186 |
+
model_name = "rf3"
|
| 187 |
+
|
| 188 |
+
@property
|
| 189 |
+
def description(self) -> str:
|
| 190 |
+
return "Sample initial Gaussian noise scaled by the first noise schedule value."
|
| 191 |
+
|
| 192 |
+
@property
|
| 193 |
+
def inputs(self) -> List[InputParam]:
|
| 194 |
+
return [
|
| 195 |
+
InputParam("generator", type_hint=torch.Generator),
|
| 196 |
+
InputParam("diffusion_batch_size", default=5, type_hint=int),
|
| 197 |
+
InputParam("L", required=True, type_hint=int),
|
| 198 |
+
InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
+
@property
|
| 202 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 203 |
+
return [
|
| 204 |
+
OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
|
| 205 |
+
]
|
| 206 |
+
|
| 207 |
+
@torch.no_grad()
|
| 208 |
+
def __call__(self, components, state):
|
| 209 |
+
block_state = self.get_block_state(state)
|
| 210 |
+
|
| 211 |
+
L = block_state.L
|
| 212 |
+
noise_schedule = block_state.noise_schedule
|
| 213 |
+
D = block_state.diffusion_batch_size or 5
|
| 214 |
+
generator = block_state.generator
|
| 215 |
+
|
| 216 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 217 |
+
c0 = noise_schedule[0]
|
| 218 |
+
xyz = c0 * torch.randn((D, L, 3), device=device, generator=generator)
|
| 219 |
+
|
| 220 |
+
block_state.xyz = xyz
|
| 221 |
+
|
| 222 |
+
self.set_block_state(state, block_state)
|
| 223 |
+
return components, state
|
decoders.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Decode step for RF3 β converts denoised coordinates to output structures.
|
| 6 |
+
Supports tensor, PDB, and CIF (via AtomWorks) output formats.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import List, Optional
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from atomworks.io.utils.io_utils import to_cif_file
|
| 15 |
+
from biotite.structure import AtomArray, AtomArrayStack, stack
|
| 16 |
+
|
| 17 |
+
from diffusers.utils import logging
|
| 18 |
+
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
|
| 19 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
AA_ORDER = "ARNDCQEGHILKMFPSTWYV"
|
| 25 |
+
AA_NAMES_3 = [
|
| 26 |
+
"ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
|
| 27 |
+
"LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _build_atom_array(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArray:
|
| 32 |
+
xyz_np = xyz.detach().cpu().float().numpy()
|
| 33 |
+
L = xyz_np.shape[0]
|
| 34 |
+
arr = AtomArray(L)
|
| 35 |
+
arr.coord = xyz_np
|
| 36 |
+
arr.atom_name = np.full(L, "CA")
|
| 37 |
+
arr.element = np.full(L, "C")
|
| 38 |
+
arr.chain_id = np.full(L, "A")
|
| 39 |
+
arr.res_id = np.arange(1, L + 1)
|
| 40 |
+
if sequence:
|
| 41 |
+
arr.res_name = np.array([
|
| 42 |
+
AA_NAMES_3[AA_ORDER.find(aa)] if aa in AA_ORDER else "UNK"
|
| 43 |
+
for aa in sequence
|
| 44 |
+
])
|
| 45 |
+
else:
|
| 46 |
+
arr.res_name = np.full(L, "ALA")
|
| 47 |
+
return arr
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _build_atom_array_stack(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArrayStack:
|
| 51 |
+
template = _build_atom_array(xyz[0], sequence)
|
| 52 |
+
arr_stack = stack([template for _ in range(xyz.shape[0])])
|
| 53 |
+
arr_stack.coord = xyz.detach().cpu().float().numpy()
|
| 54 |
+
return arr_stack
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
|
| 58 |
+
class RF3PipelineOutput:
|
| 59 |
+
"""Output class for RF3 pipeline."""
|
| 60 |
+
|
| 61 |
+
xyz: torch.Tensor
|
| 62 |
+
atom_array: Optional[AtomArray] = None
|
| 63 |
+
atom_array_stack: Optional[AtomArrayStack] = None
|
| 64 |
+
trajectory_stack: Optional[AtomArrayStack] = None
|
| 65 |
+
distogram: Optional[torch.Tensor] = None
|
| 66 |
+
sequence: Optional[str] = None
|
| 67 |
+
pdb_string: Optional[str] = None
|
| 68 |
+
trajectory: Optional[List[torch.Tensor]] = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class RF3DecodeStep(ModularPipelineBlocks):
|
| 72 |
+
"""
|
| 73 |
+
Decode step for RF3.
|
| 74 |
+
|
| 75 |
+
Supported ``output_type`` values: ``"tensor"``, ``"pdb"``, ``"cif"``, ``"cif.gz"``.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
model_name = "rf3"
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
def description(self) -> str:
|
| 82 |
+
return "Convert predicted coordinates to output format (tensor/PDB/CIF)."
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def inputs(self) -> List[InputParam]:
|
| 86 |
+
return [
|
| 87 |
+
InputParam("output_type", default="tensor", type_hint=str),
|
| 88 |
+
InputParam("output_path", type_hint=str),
|
| 89 |
+
InputParam("xyz", required=True, type_hint=torch.Tensor),
|
| 90 |
+
InputParam("sequence", type_hint=str),
|
| 91 |
+
InputParam("distogram", type_hint=torch.Tensor),
|
| 92 |
+
InputParam("trajectory", type_hint=List[torch.Tensor]),
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
@property
|
| 96 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 97 |
+
return [
|
| 98 |
+
OutputParam("output", type_hint=RF3PipelineOutput),
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
|
| 103 |
+
block_state = self.get_block_state(state)
|
| 104 |
+
|
| 105 |
+
xyz = block_state.xyz
|
| 106 |
+
sequence = block_state.sequence
|
| 107 |
+
distogram = block_state.distogram
|
| 108 |
+
trajectory = block_state.trajectory
|
| 109 |
+
output_type = block_state.output_type or "tensor"
|
| 110 |
+
output_path = block_state.output_path
|
| 111 |
+
|
| 112 |
+
pdb_string = None
|
| 113 |
+
atom_array = None
|
| 114 |
+
atom_array_stack = None
|
| 115 |
+
trajectory_stack = None
|
| 116 |
+
|
| 117 |
+
if output_type in ("cif", "cif.gz"):
|
| 118 |
+
atom_array = _build_atom_array(xyz[0], sequence)
|
| 119 |
+
if xyz.shape[0] > 1:
|
| 120 |
+
atom_array_stack = _build_atom_array_stack(xyz, sequence)
|
| 121 |
+
if trajectory:
|
| 122 |
+
traj_coords = torch.stack([t[0] for t in trajectory])
|
| 123 |
+
template = _build_atom_array(traj_coords[0], sequence)
|
| 124 |
+
trajectory_stack = stack([template for _ in range(traj_coords.shape[0])])
|
| 125 |
+
trajectory_stack.coord = traj_coords.detach().cpu().float().numpy()
|
| 126 |
+
|
| 127 |
+
if output_type == "pdb":
|
| 128 |
+
pdb_string = self._coords_to_pdb(xyz[0], sequence)
|
| 129 |
+
|
| 130 |
+
if output_path is not None:
|
| 131 |
+
import os
|
| 132 |
+
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
|
| 133 |
+
if output_type in ("cif", "cif.gz"):
|
| 134 |
+
to_write = atom_array_stack if atom_array_stack is not None else atom_array
|
| 135 |
+
base = output_path.rsplit(".", 1)[0] if "." in output_path else output_path
|
| 136 |
+
to_cif_file(to_write, base, file_type=output_type, include_entity_poly=False)
|
| 137 |
+
elif output_type == "pdb" and pdb_string:
|
| 138 |
+
with open(output_path, "w") as f:
|
| 139 |
+
f.write(pdb_string)
|
| 140 |
+
|
| 141 |
+
output = RF3PipelineOutput(
|
| 142 |
+
xyz=xyz,
|
| 143 |
+
atom_array=atom_array,
|
| 144 |
+
atom_array_stack=atom_array_stack,
|
| 145 |
+
trajectory_stack=trajectory_stack,
|
| 146 |
+
distogram=distogram,
|
| 147 |
+
sequence=sequence,
|
| 148 |
+
pdb_string=pdb_string,
|
| 149 |
+
trajectory=trajectory,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
block_state.output = output
|
| 153 |
+
self.set_block_state(state, block_state)
|
| 154 |
+
return components, state
|
| 155 |
+
|
| 156 |
+
def _coords_to_pdb(self, xyz: torch.Tensor, sequence: Optional[str] = None) -> str:
|
| 157 |
+
xyz_np = xyz.cpu().numpy()
|
| 158 |
+
L = xyz_np.shape[0]
|
| 159 |
+
lines = []
|
| 160 |
+
for i in range(L):
|
| 161 |
+
aa = sequence[i] if sequence and i < len(sequence) else "A"
|
| 162 |
+
aa3 = AA_NAMES_3[AA_ORDER.find(aa)] if aa in AA_ORDER else "UNK"
|
| 163 |
+
x, y, z = xyz_np[i, :]
|
| 164 |
+
lines.append(
|
| 165 |
+
f"ATOM {i+1:5d} CA {aa3:3s} A{i+1:4d} "
|
| 166 |
+
f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 C "
|
| 167 |
+
)
|
| 168 |
+
lines.append("END")
|
| 169 |
+
return "\n".join(lines)
|
denoise.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Denoising loop for RF3.
|
| 6 |
+
|
| 7 |
+
Same EDM stochastic sampling as RFD3, but conditioned on trunk
|
| 8 |
+
representations (single S_I, pair Z_II) from the recycling step.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Callable, List
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from diffusers.utils import logging
|
| 16 |
+
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
|
| 17 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logger = logging.get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class RF3DenoiseStep(ModularPipelineBlocks):
|
| 24 |
+
"""
|
| 25 |
+
Iterative denoising step for RF3.
|
| 26 |
+
|
| 27 |
+
Uses trunk representations from the recycling step as conditioning
|
| 28 |
+
for the diffusion module at each denoising step.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
model_name = "rf3"
|
| 32 |
+
|
| 33 |
+
@property
|
| 34 |
+
def description(self) -> str:
|
| 35 |
+
return "Iteratively denoise protein structure conditioned on sequence/MSA representations."
|
| 36 |
+
|
| 37 |
+
@property
|
| 38 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 39 |
+
return [
|
| 40 |
+
ComponentSpec("transformer", description="RF3 transformer (provides diffusion_module)"),
|
| 41 |
+
ComponentSpec("scheduler", description="RF3 EDM scheduler"),
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
@property
|
| 45 |
+
def inputs(self) -> List[InputParam]:
|
| 46 |
+
return [
|
| 47 |
+
InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
|
| 48 |
+
InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
|
| 49 |
+
InputParam("f", required=True, type_hint=dict, description="Feature dictionary"),
|
| 50 |
+
InputParam("single", type_hint=torch.Tensor, description="Trunk single repr [I, c_s]"),
|
| 51 |
+
InputParam("pair", type_hint=torch.Tensor, description="Trunk pair repr [I, I, c_z]"),
|
| 52 |
+
InputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
|
| 53 |
+
InputParam("callback", type_hint=Callable),
|
| 54 |
+
InputParam("callback_steps", default=1, type_hint=int),
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 59 |
+
return [
|
| 60 |
+
OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coords [D, L, 3]"),
|
| 61 |
+
OutputParam("trajectory", type_hint=List[torch.Tensor]),
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
@torch.no_grad()
|
| 65 |
+
def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
|
| 66 |
+
block_state = self.get_block_state(state)
|
| 67 |
+
|
| 68 |
+
xyz = block_state.xyz
|
| 69 |
+
noise_schedule = block_state.noise_schedule
|
| 70 |
+
f = block_state.f
|
| 71 |
+
single = block_state.single
|
| 72 |
+
pair = block_state.pair
|
| 73 |
+
s_inputs = block_state.s_inputs
|
| 74 |
+
callback = block_state.callback
|
| 75 |
+
callback_steps = block_state.callback_steps or 1
|
| 76 |
+
|
| 77 |
+
X_L = xyz.clone()
|
| 78 |
+
D = X_L.shape[0]
|
| 79 |
+
device = X_L.device
|
| 80 |
+
noise_schedule = noise_schedule.to(device)
|
| 81 |
+
|
| 82 |
+
trajectory = []
|
| 83 |
+
|
| 84 |
+
has_transformer = hasattr(components, "transformer") and components.transformer is not None
|
| 85 |
+
has_scheduler = hasattr(components, "scheduler") and components.scheduler is not None
|
| 86 |
+
|
| 87 |
+
for step_num in range(len(noise_schedule) - 1):
|
| 88 |
+
c_t_minus_1 = noise_schedule[step_num]
|
| 89 |
+
c_t = noise_schedule[step_num + 1]
|
| 90 |
+
|
| 91 |
+
# Noise injection
|
| 92 |
+
if has_scheduler:
|
| 93 |
+
X_noisy, t_hat = components.scheduler.add_noise(X_L, c_t_minus_1, c_t)
|
| 94 |
+
else:
|
| 95 |
+
X_noisy = X_L
|
| 96 |
+
t_hat = c_t_minus_1
|
| 97 |
+
|
| 98 |
+
# Model forward (diffusion module conditioned on trunk)
|
| 99 |
+
if has_transformer:
|
| 100 |
+
t_batch = (t_hat.to(device).expand(D) if isinstance(t_hat, torch.Tensor)
|
| 101 |
+
else torch.full((D,), t_hat, device=device))
|
| 102 |
+
|
| 103 |
+
outs = components.transformer.diffusion_module(
|
| 104 |
+
X_noisy_L=X_noisy,
|
| 105 |
+
t=t_batch,
|
| 106 |
+
f=f,
|
| 107 |
+
S_inputs_I=s_inputs,
|
| 108 |
+
S_trunk_I=single,
|
| 109 |
+
Z_trunk_II=pair,
|
| 110 |
+
)
|
| 111 |
+
X_denoised = outs if isinstance(outs, torch.Tensor) else outs.get("X_L", outs)
|
| 112 |
+
else:
|
| 113 |
+
X_denoised = X_noisy
|
| 114 |
+
|
| 115 |
+
# Euler step
|
| 116 |
+
if has_scheduler:
|
| 117 |
+
X_L = components.scheduler.step(
|
| 118 |
+
xyz_pred=X_denoised, xyz_noisy=X_noisy,
|
| 119 |
+
c_t_minus_1=c_t_minus_1, c_t=c_t,
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
delta = (X_noisy - X_denoised) / (t_hat + 1e-8)
|
| 123 |
+
d_t = c_t - t_hat
|
| 124 |
+
X_L = X_noisy + d_t * delta
|
| 125 |
+
|
| 126 |
+
trajectory.append(X_denoised.clone())
|
| 127 |
+
|
| 128 |
+
if callback is not None and step_num % callback_steps == 0:
|
| 129 |
+
callback(step_num, c_t_minus_1, X_L)
|
| 130 |
+
|
| 131 |
+
block_state.xyz = X_L
|
| 132 |
+
block_state.trajectory = trajectory
|
| 133 |
+
|
| 134 |
+
self.set_block_state(state, block_state)
|
| 135 |
+
return components, state
|
modular_blocks.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
from diffusers.utils import logging
|
| 5 |
+
from diffusers.modular_pipelines import AutoPipelineBlocks, SequentialPipelineBlocks
|
| 6 |
+
|
| 7 |
+
from .before_denoise import (
|
| 8 |
+
RF3InputStep,
|
| 9 |
+
RF3PrepareLatentsStep,
|
| 10 |
+
RF3RecyclingStep,
|
| 11 |
+
RF3SetTimestepsStep,
|
| 12 |
+
)
|
| 13 |
+
from .decoders import RF3DecodeStep
|
| 14 |
+
from .denoise import RF3DenoiseStep
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RF3BeforeDenoiseStep(SequentialPipelineBlocks):
|
| 21 |
+
"""Sequential block for pre-denoising: input β timesteps β recycling β latents."""
|
| 22 |
+
|
| 23 |
+
block_classes = [
|
| 24 |
+
RF3InputStep,
|
| 25 |
+
RF3SetTimestepsStep,
|
| 26 |
+
RF3RecyclingStep,
|
| 27 |
+
RF3PrepareLatentsStep,
|
| 28 |
+
]
|
| 29 |
+
block_names = ["input", "set_timesteps", "recycling", "prepare_latents"]
|
| 30 |
+
|
| 31 |
+
@property
|
| 32 |
+
def description(self):
|
| 33 |
+
return (
|
| 34 |
+
"Before denoise step:\n"
|
| 35 |
+
" - `RF3InputStep` parses sequence and builds feature dict\n"
|
| 36 |
+
" - `RF3SetTimestepsStep` constructs EDM noise schedule\n"
|
| 37 |
+
" - `RF3RecyclingStep` runs trunk recycler (pairformer + MSA + templates)\n"
|
| 38 |
+
" - `RF3PrepareLatentsStep` samples initial noised coordinates\n"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class RF3AutoBeforeDenoiseStep(AutoPipelineBlocks):
|
| 43 |
+
block_classes = [RF3BeforeDenoiseStep]
|
| 44 |
+
block_names = ["fold"]
|
| 45 |
+
block_trigger_inputs = [None]
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def description(self):
|
| 49 |
+
return "Before denoise step for RF3 structure prediction."
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class RF3AutoDenoiseStep(AutoPipelineBlocks):
|
| 53 |
+
block_classes = [RF3DenoiseStep]
|
| 54 |
+
block_names = ["denoise"]
|
| 55 |
+
block_trigger_inputs = [None]
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def description(self) -> str:
|
| 59 |
+
return "Denoise step for RF3 structure prediction."
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class RF3AutoDecodeStep(AutoPipelineBlocks):
|
| 63 |
+
block_classes = [RF3DecodeStep]
|
| 64 |
+
block_names = ["decode"]
|
| 65 |
+
block_trigger_inputs = [None]
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def description(self):
|
| 69 |
+
return "Decode step for RF3 β coordinates to tensor/PDB/CIF."
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RF3AutoBlocks(SequentialPipelineBlocks):
|
| 73 |
+
"""Full RF3 structure prediction pipeline."""
|
| 74 |
+
|
| 75 |
+
block_classes = [
|
| 76 |
+
RF3AutoBeforeDenoiseStep,
|
| 77 |
+
RF3AutoDenoiseStep,
|
| 78 |
+
RF3AutoDecodeStep,
|
| 79 |
+
]
|
| 80 |
+
block_names = [
|
| 81 |
+
"before_denoise",
|
| 82 |
+
"denoise",
|
| 83 |
+
"decoder",
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def description(self):
|
| 88 |
+
return (
|
| 89 |
+
"Modular pipeline for protein structure prediction using RF3.\n"
|
| 90 |
+
"Provide `sequence` to predict a protein's 3D structure."
|
| 91 |
+
)
|
modular_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "RF3AutoBlocks",
|
| 3 |
+
"_diffusers_version": "0.37.0.dev0",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"ModularPipelineBlocks": "modular_blocks.RF3AutoBlocks"
|
| 6 |
+
}
|
| 7 |
+
}
|
modular_model_index.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_blocks_class_name": "RF3AutoBlocks",
|
| 3 |
+
"_class_name": "ModularPipeline",
|
| 4 |
+
"_diffusers_version": "0.37.0.dev0",
|
| 5 |
+
"transformer": [
|
| 6 |
+
null,
|
| 7 |
+
null,
|
| 8 |
+
{
|
| 9 |
+
"pretrained_model_name_or_path": "dn6/RosettaFold-3",
|
| 10 |
+
"subfolder": "transformer",
|
| 11 |
+
"type_hint": [
|
| 12 |
+
"diffusers",
|
| 13 |
+
"AutoModel"
|
| 14 |
+
],
|
| 15 |
+
"revision": null,
|
| 16 |
+
"variant": null
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"scheduler": [
|
| 20 |
+
null,
|
| 21 |
+
null,
|
| 22 |
+
{
|
| 23 |
+
"pretrained_model_name_or_path": "dn6/RosettaFold-3",
|
| 24 |
+
"subfolder": "scheduler",
|
| 25 |
+
"type_hint": [
|
| 26 |
+
"diffusers",
|
| 27 |
+
"AutoModel"
|
| 28 |
+
],
|
| 29 |
+
"revision": null,
|
| 30 |
+
"variant": null,
|
| 31 |
+
"default_creation_method": "from_config"
|
| 32 |
+
}
|
| 33 |
+
]
|
| 34 |
+
}
|
scheduler/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
from .model import RF3Scheduler
|
scheduler/config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "RF3Scheduler",
|
| 3 |
+
"_diffusers_version": "0.37.0.dev0",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoModel": "model.RF3Scheduler"
|
| 6 |
+
},
|
| 7 |
+
"num_timesteps": 200,
|
| 8 |
+
"sigma_data": 16.0,
|
| 9 |
+
"s_min": 4e-4,
|
| 10 |
+
"s_max": 160.0,
|
| 11 |
+
"p": 7.0,
|
| 12 |
+
"gamma_0": 0.8,
|
| 13 |
+
"gamma_min": 1.0,
|
| 14 |
+
"noise_scale": 1.003,
|
| 15 |
+
"step_scale": 1.5
|
| 16 |
+
}
|
scheduler/model.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
RF3 Scheduler.
|
| 6 |
+
|
| 7 |
+
A diffusers-compatible wrapper around the foundry EDM noise schedule
|
| 8 |
+
for RF3. Same schedule formula as RFD3 but with gamma_0=0.8 (vs 0.6).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 16 |
+
|
| 17 |
+
from rf3.diffusion_samplers.inference_sampler import SampleDiffusion
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RF3Scheduler(ConfigMixin):
|
| 21 |
+
"""
|
| 22 |
+
Diffusers-compatible scheduler wrapping the foundry RF3 EDM sampler.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
config_name = "config.json"
|
| 26 |
+
|
| 27 |
+
@register_to_config
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
num_timesteps: int = 200,
|
| 31 |
+
sigma_data: float = 16.0,
|
| 32 |
+
s_min: float = 4e-4,
|
| 33 |
+
s_max: float = 160.0,
|
| 34 |
+
p: float = 7.0,
|
| 35 |
+
gamma_0: float = 0.8,
|
| 36 |
+
gamma_min: float = 1.0,
|
| 37 |
+
noise_scale: float = 1.003,
|
| 38 |
+
step_scale: float = 1.5,
|
| 39 |
+
):
|
| 40 |
+
self._sampler = SampleDiffusion(
|
| 41 |
+
num_timesteps=num_timesteps,
|
| 42 |
+
min_t=0,
|
| 43 |
+
max_t=1,
|
| 44 |
+
sigma_data=sigma_data,
|
| 45 |
+
s_min=s_min,
|
| 46 |
+
s_max=s_max,
|
| 47 |
+
p=p,
|
| 48 |
+
gamma_0=gamma_0,
|
| 49 |
+
gamma_min=gamma_min,
|
| 50 |
+
noise_scale=noise_scale,
|
| 51 |
+
step_scale=step_scale,
|
| 52 |
+
solver="af3",
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def sampler(self) -> SampleDiffusion:
|
| 57 |
+
return self._sampler
|
| 58 |
+
|
| 59 |
+
def get_noise_schedule(self, device: torch.device = None) -> torch.Tensor:
|
| 60 |
+
"""Construct the EDM noise schedule."""
|
| 61 |
+
return self._sampler._construct_inference_noise_schedule(
|
| 62 |
+
device=device or torch.device("cpu")
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def add_noise(
|
| 66 |
+
self,
|
| 67 |
+
xyz: torch.Tensor,
|
| 68 |
+
c_t_minus_1: torch.Tensor,
|
| 69 |
+
c_t: torch.Tensor,
|
| 70 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 71 |
+
"""Inject stochastic noise before the model call."""
|
| 72 |
+
gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
|
| 73 |
+
t_hat = c_t_minus_1 * (gamma + 1.0)
|
| 74 |
+
noise_std = self._sampler.noise_scale * torch.sqrt(t_hat**2 - c_t_minus_1**2)
|
| 75 |
+
epsilon = noise_std * torch.randn_like(xyz)
|
| 76 |
+
return xyz + epsilon, t_hat
|
| 77 |
+
|
| 78 |
+
def step(
|
| 79 |
+
self,
|
| 80 |
+
xyz_pred: torch.Tensor,
|
| 81 |
+
xyz_noisy: torch.Tensor,
|
| 82 |
+
c_t_minus_1: torch.Tensor,
|
| 83 |
+
c_t: torch.Tensor,
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
"""Perform one Euler denoising step."""
|
| 86 |
+
gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
|
| 87 |
+
t_hat = c_t_minus_1 * (gamma + 1.0)
|
| 88 |
+
delta = (xyz_noisy - xyz_pred) / t_hat
|
| 89 |
+
d_t = c_t - t_hat
|
| 90 |
+
return xyz_noisy + self._sampler.step_scale * d_t * delta
|
transformer/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
from .model import RF3TransformerModel, RF3TransformerOutput
|
transformer/config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "RF3TransformerModel",
|
| 3 |
+
"_diffusers_version": "0.37.0.dev0",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoModel": "model.RF3TransformerModel"
|
| 6 |
+
},
|
| 7 |
+
"c_s": 384,
|
| 8 |
+
"c_z": 128,
|
| 9 |
+
"c_atom": 128,
|
| 10 |
+
"c_atompair": 16,
|
| 11 |
+
"c_s_inputs": 449,
|
| 12 |
+
"c_token": 768,
|
| 13 |
+
"sigma_data": 16.0,
|
| 14 |
+
"n_pairformer_blocks": 48,
|
| 15 |
+
"n_diffusion_blocks": 24,
|
| 16 |
+
"n_atom_encoder_blocks": 3,
|
| 17 |
+
"n_atom_decoder_blocks": 3,
|
| 18 |
+
"n_msa_blocks": 4,
|
| 19 |
+
"n_template_blocks": 2,
|
| 20 |
+
"n_head": 16,
|
| 21 |
+
"n_pairformer_head": 16,
|
| 22 |
+
"n_recycles": 10,
|
| 23 |
+
"distogram_bins": 65,
|
| 24 |
+
"p_drop": 0.25
|
| 25 |
+
}
|
transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0d30d1003e326ef59e8154bc3cf3af928562715fa5410acb63de88c6486ae275
|
| 3 |
+
size 1466334428
|
transformer/model.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
RF3 (RosettaFold3) Transformer model.
|
| 6 |
+
|
| 7 |
+
A diffusers-compatible wrapper around the foundry RF3 model components.
|
| 8 |
+
Reuses FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
|
| 9 |
+
from ``rf3.model.*`` directly, adding only the ModelMixin/ConfigMixin
|
| 10 |
+
interface needed for diffusers ModularPipeline integration.
|
| 11 |
+
|
| 12 |
+
RF3 is structurally similar to RFD3 but adds a trunk recycler (48
|
| 13 |
+
pairformer blocks + MSA + templates) for sequence-conditioned folding.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 24 |
+
|
| 25 |
+
from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler
|
| 26 |
+
from rf3.model.layers.pairformer_layers import FeatureInitializer
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class RF3TransformerOutput:
|
| 31 |
+
"""Output class for RF3 transformer."""
|
| 32 |
+
|
| 33 |
+
xyz: torch.Tensor # [D, L, 3]
|
| 34 |
+
distogram: Optional[torch.Tensor] = None # [I, I, bins]
|
| 35 |
+
single: Optional[torch.Tensor] = None # [I, c_s]
|
| 36 |
+
pair: Optional[torch.Tensor] = None # [I, I, c_z]
|
| 37 |
+
trajectory_noisy: Optional[list] = None # list of [D, L, 3]
|
| 38 |
+
trajectory_denoised: Optional[list] = None # list of [D, L, 3]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class RF3TransformerModel(ModelMixin, ConfigMixin):
|
| 42 |
+
"""
|
| 43 |
+
Diffusers-compatible wrapper around the foundry RF3 model.
|
| 44 |
+
|
| 45 |
+
Wraps FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
|
| 46 |
+
to provide a diffusers ModelMixin/ConfigMixin interface.
|
| 47 |
+
|
| 48 |
+
State dict keys match the foundry checkpoint format via the
|
| 49 |
+
``feature_initializer.*``, ``recycler.*``, ``diffusion_module.*``,
|
| 50 |
+
and ``distogram_head.*`` prefixes.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
config_name = "config.json"
|
| 54 |
+
_supports_gradient_checkpointing = True
|
| 55 |
+
|
| 56 |
+
@register_to_config
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
c_s: int = 384,
|
| 60 |
+
c_z: int = 128,
|
| 61 |
+
c_atom: int = 128,
|
| 62 |
+
c_atompair: int = 16,
|
| 63 |
+
c_s_inputs: int = 449,
|
| 64 |
+
c_token: int = 768,
|
| 65 |
+
sigma_data: float = 16.0,
|
| 66 |
+
n_pairformer_blocks: int = 48,
|
| 67 |
+
n_diffusion_blocks: int = 24,
|
| 68 |
+
n_atom_encoder_blocks: int = 3,
|
| 69 |
+
n_atom_decoder_blocks: int = 3,
|
| 70 |
+
n_msa_blocks: int = 4,
|
| 71 |
+
n_template_blocks: int = 2,
|
| 72 |
+
n_head: int = 16,
|
| 73 |
+
n_pairformer_head: int = 16,
|
| 74 |
+
n_recycles: int = 10,
|
| 75 |
+
distogram_bins: int = 65,
|
| 76 |
+
p_drop: float = 0.25,
|
| 77 |
+
):
|
| 78 |
+
super().__init__()
|
| 79 |
+
|
| 80 |
+
# ββ FeatureInitializer ββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
+
self.feature_initializer = FeatureInitializer(
|
| 82 |
+
c_s=c_s,
|
| 83 |
+
c_z=c_z,
|
| 84 |
+
c_atom=c_atom,
|
| 85 |
+
c_atompair=c_atompair,
|
| 86 |
+
c_s_inputs=c_s_inputs,
|
| 87 |
+
input_feature_embedder={
|
| 88 |
+
"features": ["restype", "profile", "deletion_mean"],
|
| 89 |
+
"atom_attention_encoder": {
|
| 90 |
+
"c_token": c_s,
|
| 91 |
+
"c_atom_1d_features": 389,
|
| 92 |
+
"c_tokenpair": c_z,
|
| 93 |
+
"use_inv_dist_squared": True,
|
| 94 |
+
"atom_1d_features": [
|
| 95 |
+
"ref_pos", "ref_charge", "ref_mask",
|
| 96 |
+
"ref_element", "ref_atom_name_chars",
|
| 97 |
+
],
|
| 98 |
+
"atom_transformer": {
|
| 99 |
+
"n_queries": 32,
|
| 100 |
+
"n_keys": 128,
|
| 101 |
+
"diffusion_transformer": {
|
| 102 |
+
"n_block": 3,
|
| 103 |
+
"diffusion_transformer_block": {
|
| 104 |
+
"n_head": 4,
|
| 105 |
+
"no_residual_connection_between_attention_and_transition": True,
|
| 106 |
+
"kq_norm": True,
|
| 107 |
+
},
|
| 108 |
+
},
|
| 109 |
+
},
|
| 110 |
+
},
|
| 111 |
+
},
|
| 112 |
+
relative_position_encoding={"r_max": 32, "s_max": 2},
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# ββ Recycler (trunk) βββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
+
self.recycler = Recycler(
|
| 117 |
+
c_s=c_s,
|
| 118 |
+
c_z=c_z,
|
| 119 |
+
n_pairformer_blocks=n_pairformer_blocks,
|
| 120 |
+
pairformer_block={
|
| 121 |
+
"p_drop": p_drop,
|
| 122 |
+
"triangle_multiplication": {"d_hidden": 128},
|
| 123 |
+
"triangle_attention": {"n_head": 4, "d_hidden": 32},
|
| 124 |
+
"attention_pair_bias": {"n_head": n_head},
|
| 125 |
+
},
|
| 126 |
+
template_embedder={
|
| 127 |
+
"n_block": n_template_blocks,
|
| 128 |
+
"raw_template_dim": 108,
|
| 129 |
+
"c": 64,
|
| 130 |
+
"p_drop": p_drop,
|
| 131 |
+
},
|
| 132 |
+
msa_module={
|
| 133 |
+
"n_block": n_msa_blocks,
|
| 134 |
+
"c_m": 64,
|
| 135 |
+
"p_drop_msa": 0.15,
|
| 136 |
+
"p_drop_pair": p_drop,
|
| 137 |
+
"msa_subsample_embedder": {
|
| 138 |
+
"num_sequences": 1024,
|
| 139 |
+
"dim_raw_msa": 34,
|
| 140 |
+
"c_s_inputs": c_s_inputs,
|
| 141 |
+
"c_msa_embed": 64,
|
| 142 |
+
},
|
| 143 |
+
"outer_product": {
|
| 144 |
+
"c_msa_embed": 64,
|
| 145 |
+
"c_outer_product": 32,
|
| 146 |
+
"c_out": c_z,
|
| 147 |
+
},
|
| 148 |
+
"msa_pair_weighted_averaging": {
|
| 149 |
+
"n_heads": 8,
|
| 150 |
+
"c_weighted_average": 32,
|
| 151 |
+
"c_msa_embed": 64,
|
| 152 |
+
"c_z": c_z,
|
| 153 |
+
"separate_gate_for_every_channel": True,
|
| 154 |
+
},
|
| 155 |
+
"msa_transition": {"n": 4, "c": 64},
|
| 156 |
+
"triangle_multiplication_outgoing": {
|
| 157 |
+
"d_pair": c_z, "d_hidden": 128, "bias": True,
|
| 158 |
+
},
|
| 159 |
+
"triangle_multiplication_incoming": {
|
| 160 |
+
"d_pair": c_z, "d_hidden": 128, "bias": True,
|
| 161 |
+
},
|
| 162 |
+
"triangle_attention_starting": {
|
| 163 |
+
"d_pair": c_z, "n_head": 4, "d_hidden": 32, "p_drop": 0.0,
|
| 164 |
+
},
|
| 165 |
+
"triangle_attention_ending": {
|
| 166 |
+
"d_pair": c_z, "n_head": 4, "d_hidden": 32, "p_drop": 0.0,
|
| 167 |
+
},
|
| 168 |
+
"pair_transition": {"n": 4, "c": c_z},
|
| 169 |
+
},
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# ββ DiffusionModule ββββββββββββββββββββββββββββββββββββββββββββ
|
| 173 |
+
self.diffusion_module = DiffusionModule(
|
| 174 |
+
sigma_data=sigma_data,
|
| 175 |
+
c_atom=c_atom,
|
| 176 |
+
c_atompair=c_atompair,
|
| 177 |
+
c_token=c_token,
|
| 178 |
+
c_s=c_s,
|
| 179 |
+
c_z=c_z,
|
| 180 |
+
diffusion_conditioning={
|
| 181 |
+
"c_s_inputs": c_s_inputs,
|
| 182 |
+
"c_t_embed": 256,
|
| 183 |
+
"relative_position_encoding": {"r_max": 32, "s_max": 2},
|
| 184 |
+
},
|
| 185 |
+
atom_attention_encoder={
|
| 186 |
+
"c_tokenpair": c_z,
|
| 187 |
+
"c_atom_1d_features": 389,
|
| 188 |
+
"use_inv_dist_squared": True,
|
| 189 |
+
"atom_1d_features": [
|
| 190 |
+
"ref_pos", "ref_charge", "ref_mask",
|
| 191 |
+
"ref_element", "ref_atom_name_chars",
|
| 192 |
+
],
|
| 193 |
+
"atom_transformer": {
|
| 194 |
+
"n_queries": 32,
|
| 195 |
+
"n_keys": 128,
|
| 196 |
+
"diffusion_transformer": {
|
| 197 |
+
"n_block": n_atom_encoder_blocks,
|
| 198 |
+
"diffusion_transformer_block": {
|
| 199 |
+
"n_head": 4,
|
| 200 |
+
"no_residual_connection_between_attention_and_transition": True,
|
| 201 |
+
"kq_norm": True,
|
| 202 |
+
},
|
| 203 |
+
},
|
| 204 |
+
},
|
| 205 |
+
"broadcast_trunk_feats_on_1dim_old": False,
|
| 206 |
+
"use_chiral_features": True,
|
| 207 |
+
"no_grad_on_chiral_center": False,
|
| 208 |
+
},
|
| 209 |
+
diffusion_transformer={
|
| 210 |
+
"n_block": n_diffusion_blocks,
|
| 211 |
+
"diffusion_transformer_block": {
|
| 212 |
+
"n_head": n_head,
|
| 213 |
+
"no_residual_connection_between_attention_and_transition": True,
|
| 214 |
+
"kq_norm": True,
|
| 215 |
+
},
|
| 216 |
+
},
|
| 217 |
+
atom_attention_decoder={
|
| 218 |
+
"atom_transformer": {
|
| 219 |
+
"n_queries": 32,
|
| 220 |
+
"n_keys": 128,
|
| 221 |
+
"diffusion_transformer": {
|
| 222 |
+
"n_block": n_atom_decoder_blocks,
|
| 223 |
+
"diffusion_transformer_block": {
|
| 224 |
+
"n_head": 4,
|
| 225 |
+
"no_residual_connection_between_attention_and_transition": True,
|
| 226 |
+
"kq_norm": True,
|
| 227 |
+
},
|
| 228 |
+
},
|
| 229 |
+
},
|
| 230 |
+
},
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# ββ DistogramHead ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 234 |
+
self.distogram_head = DistogramHead(c_z=c_z, bins=distogram_bins)
|
| 235 |
+
|
| 236 |
+
self._n_recycles = n_recycles
|
| 237 |
+
|
| 238 |
+
def forward(
|
| 239 |
+
self,
|
| 240 |
+
f: dict,
|
| 241 |
+
n_recycles: Optional[int] = None,
|
| 242 |
+
diffusion_batch_size: int = 1,
|
| 243 |
+
coord_atom_lvl_to_be_noised: Optional[torch.Tensor] = None,
|
| 244 |
+
) -> RF3TransformerOutput:
|
| 245 |
+
"""
|
| 246 |
+
Forward pass: recycling trunk β diffusion sampling.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
f: Feature dictionary (sequence, MSA, templates, atom features).
|
| 250 |
+
n_recycles: Number of recycling iterations (default: config value).
|
| 251 |
+
diffusion_batch_size: Number of diffusion samples.
|
| 252 |
+
coord_atom_lvl_to_be_noised: Initial coordinates for partial diffusion.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
RF3TransformerOutput with predicted coordinates and distogram.
|
| 256 |
+
"""
|
| 257 |
+
n_recycles = n_recycles or self._n_recycles
|
| 258 |
+
|
| 259 |
+
# Pre-recycle: initialize features
|
| 260 |
+
initialized = self.feature_initializer(f)
|
| 261 |
+
S_inputs_I = initialized["S_inputs_I"]
|
| 262 |
+
S_I = initialized.get("S_init_I", initialized.get("S_I"))
|
| 263 |
+
Z_II = initialized.get("Z_init_II", initialized.get("Z_II"))
|
| 264 |
+
|
| 265 |
+
# Recycling trunk
|
| 266 |
+
for i in range(n_recycles):
|
| 267 |
+
ctx = torch.no_grad() if i < n_recycles - 1 else torch.enable_grad()
|
| 268 |
+
with ctx:
|
| 269 |
+
recycled = self.recycler(
|
| 270 |
+
S_I=S_I,
|
| 271 |
+
Z_II=Z_II,
|
| 272 |
+
S_inputs_I=S_inputs_I,
|
| 273 |
+
f=f,
|
| 274 |
+
)
|
| 275 |
+
S_I = recycled["S_I"]
|
| 276 |
+
Z_II = recycled["Z_II"]
|
| 277 |
+
|
| 278 |
+
# Distogram prediction
|
| 279 |
+
distogram = self.distogram_head(Z_II)
|
| 280 |
+
|
| 281 |
+
return RF3TransformerOutput(
|
| 282 |
+
xyz=torch.zeros(1), # placeholder β filled by sampler in denoise step
|
| 283 |
+
distogram=distogram,
|
| 284 |
+
single=S_I,
|
| 285 |
+
pair=Z_II,
|
| 286 |
+
)
|