Instructions to use dn6/RFDiffusion-3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use dn6/RFDiffusion-3 with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("dn6/RFDiffusion-3", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright 2025 Dhruv Nair. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| ProteinMPNN / LigandMPNN model wrapper. | |
| A thin diffusers-compatible wrapper around the foundry MPNN model, | |
| following the same pattern as the transformer and scheduler wrappers. | |
| Reuses the foundry model implementation directly, adding only the | |
| ModelMixin/ConfigMixin interface for diffusers integration. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from mpnn.model.mpnn import LigandMPNN, ProteinMPNN | |
# Registry mapping the ``model_type`` string accepted by MPNNModel to the
# foundry model class it instantiates.
MODEL_CLASSES = dict(
    protein_mpnn=ProteinMPNN,
    ligand_mpnn=LigandMPNN,
)
@dataclass
class MPNNModelOutput:
    """Output from the MPNN model wrapper.

    Instances are built with keyword arguments by ``MPNNModel.forward``;
    the ``@dataclass`` decorator supplies the keyword constructor.
    """

    # Per-position logits over the sequence vocabulary, [B, L, n_vocab].
    sequence_logits: torch.Tensor
    # Sampled token indices (argmax of the logits when the model did not
    # provide sampled indices), [B, L].
    sequence_indices: torch.Tensor
    # Full decoder output dict returned by the foundry model.
    decoder_features: dict
class MPNNModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.

    Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
    diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
    to the foundry implementation.

    The foundry module is stored as ``self.model``, so this wrapper's state
    dict keys are the foundry checkpoint keys under a ``model.`` prefix.
    NOTE(review): no prefix stripping on load is implemented in this class;
    confirm where (or whether) that happens before relying on it.
    """

    config_name = "config.json"

    # register_to_config records the constructor arguments on self.config so
    # that save_pretrained / from_pretrained can round-trip the model config.
    @register_to_config
    def __init__(
        self,
        model_type: str = "protein_mpnn",
        hidden_dim: int = 128,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        num_neighbors: int = 48,
        dropout_rate: float = 0.1,
        num_positional_embeddings: int = 16,
        min_rbf_mean: float = 2.0,
        max_rbf_mean: float = 22.0,
        num_rbf: int = 16,
        # LigandMPNN-specific
        num_context_atoms: int = 25,
        num_context_encoding_layers: int = 2,
    ):
        """
        Build the underlying foundry model.

        Args:
            model_type: Key into ``MODEL_CLASSES``: "protein_mpnn" or
                "ligand_mpnn".
            hidden_dim: Used as the node-feature, edge-feature, and hidden
                width of the foundry model.
            num_encoder_layers, num_decoder_layers, num_neighbors,
            dropout_rate, num_positional_embeddings, min_rbf_mean,
            max_rbf_mean, num_rbf: Forwarded unchanged to the foundry model.
            num_context_atoms, num_context_encoding_layers: Forwarded only
                when ``model_type == "ligand_mpnn"``; ignored otherwise.

        Raises:
            ValueError: If ``model_type`` is not a key of ``MODEL_CLASSES``.
        """
        super().__init__()
        model_cls = MODEL_CLASSES.get(model_type)
        if model_cls is None:
            raise ValueError(
                f"Unknown model_type '{model_type}'. "
                f"Choose from: {list(MODEL_CLASSES.keys())}"
            )
        common_kwargs = dict(
            num_node_features=hidden_dim,
            num_edge_features=hidden_dim,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
        )
        # Ligand-context arguments only exist on the LigandMPNN constructor.
        if model_type == "ligand_mpnn":
            common_kwargs["num_context_atoms"] = num_context_atoms
            common_kwargs["num_context_encoding_layers"] = num_context_encoding_layers
        self.model = model_cls(**common_kwargs)

    def forward(
        self,
        X: torch.Tensor,
        S: Optional[torch.Tensor] = None,
        residue_mask: Optional[torch.Tensor] = None,
        designed_residue_mask: Optional[torch.Tensor] = None,
        chain_labels: Optional[torch.Tensor] = None,
        R_idx: Optional[torch.Tensor] = None,
        temperature: float = 0.1,
        **kwargs,
    ) -> MPNNModelOutput:
        """
        Run ProteinMPNN / LigandMPNN sequence design.

        Args:
            X: Backbone atom coordinates [B, L, num_atoms, 3].
                For ProteinMPNN: num_atoms=4 (N, CA, C, O).
            S: Ground-truth sequence tokens [B, L] (optional, for teacher forcing).
                Defaults to all zeros.
            residue_mask: Valid residue mask [B, L] (default: all valid).
            designed_residue_mask: Which residues to design [B, L] (default: all).
            chain_labels: Chain identifiers [B, L] (default: single chain).
            R_idx: Residue indices [B, L] (default: 0..L-1).
            temperature: Sampling temperature (default: 0.1).
            **kwargs: Extra entries merged into the network input dict passed
                to the foundry model; on key collision they override the
                entries built here.

        Returns:
            MPNNModelOutput with sequence logits, sampled indices, and the
            full decoder output dict.
        """
        B, L = X.shape[0], X.shape[1]
        device = X.device
        if S is None:
            S = torch.zeros(B, L, dtype=torch.long, device=device)
        if residue_mask is None:
            residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if designed_residue_mask is None:
            designed_residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if chain_labels is None:
            chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
        if R_idx is None:
            # Broadcast view of 0..L-1 for every batch element.
            R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)

        # Per-atom presence mask: an atom counts as present when any of its
        # coordinates is non-zero, so atoms placed exactly at the origin
        # (presumably zero-padded missing atoms) are masked out.
        X_m = (X.abs().sum(dim=-1) > 0).float()  # [B, L, num_atoms]

        network_input = {
            "X": X,
            "X_m": X_m,
            "S": S,
            "R_idx": R_idx,
            "chain_labels": chain_labels,
            "residue_mask": residue_mask,
            "designed_residue_mask": designed_residue_mask,
            "temperature": temperature,
            **kwargs,
        }
        output = self.model(network_input)

        decoder_features = output["decoder_features"]
        logits = decoder_features["logits"]  # [B, L, n_vocab]
        # Prefer the model's sampled indices; fall back to argmax lazily so it
        # is only computed when actually needed (and never returned as None).
        S_sampled = decoder_features.get("S_sampled")
        if S_sampled is None:
            S_sampled = logits.argmax(dim=-1)

        return MPNNModelOutput(
            sequence_logits=logits,
            sequence_indices=S_sampled,
            decoder_features=decoder_features,
        )