RFDiffusion-3 / mpnn /model_mpnn.py
dn6's picture
dn6 HF Staff
Upload folder using huggingface_hub
4900749 verified
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ProteinMPNN / LigandMPNN model wrapper.
A thin diffusers-compatible wrapper around the foundry MPNN model,
following the same pattern as the transformer and scheduler wrappers.
Reuses the foundry model implementation directly, adding only the
ModelMixin/ConfigMixin interface for diffusers integration.
"""
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from mpnn.model.mpnn import LigandMPNN, ProteinMPNN
# Registry mapping the ``model_type`` config string to the foundry model class
# that ``MPNNModel`` instantiates.
MODEL_CLASSES = dict(
    protein_mpnn=ProteinMPNN,
    ligand_mpnn=LigandMPNN,
)
@dataclass
class MPNNModelOutput:
    """Container for the tensors produced by the MPNN model wrapper's forward pass.

    Fields are positional in the order declared below.
    """

    # Per-residue logits over the sequence vocabulary, [B, L, n_vocab].
    sequence_logits: torch.Tensor
    # Per-residue sequence token indices, [B, L].
    sequence_indices: torch.Tensor
    # The complete decoder output dictionary, passed through unchanged.
    decoder_features: dict
class MPNNModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.

    Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
    diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
    to the foundry implementation stored as ``self.model``.

    Because the foundry module lives under the ``model`` attribute, every key
    in this wrapper's state dict is the corresponding foundry key prefixed
    with ``model.``. NOTE(review): converting raw foundry checkpoints
    (adding/stripping that prefix) is not implemented in this class —
    confirm where it happens.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        model_type: str = "protein_mpnn",
        hidden_dim: int = 128,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        num_neighbors: int = 48,
        dropout_rate: float = 0.1,
        num_positional_embeddings: int = 16,
        min_rbf_mean: float = 2.0,
        max_rbf_mean: float = 22.0,
        num_rbf: int = 16,
        # LigandMPNN-specific
        num_context_atoms: int = 25,
        num_context_encoding_layers: int = 2,
    ):
        """Build the wrapped foundry model from config values.

        Args:
            model_type: Key into ``MODEL_CLASSES`` ("protein_mpnn" or
                "ligand_mpnn").
            hidden_dim: Width used for the node features, edge features, and
                hidden layers of the foundry model.
            num_encoder_layers: Number of encoder layers.
            num_decoder_layers: Number of decoder layers.
            num_neighbors: Neighbor count forwarded to the foundry model.
            dropout_rate: Dropout probability.
            num_positional_embeddings: Positional-embedding size.
            min_rbf_mean: Lower bound forwarded as ``min_rbf_mean``.
            max_rbf_mean: Upper bound forwarded as ``max_rbf_mean``.
            num_rbf: Number of RBF features.
            num_context_atoms: Forwarded only when ``model_type`` is
                "ligand_mpnn".
            num_context_encoding_layers: Forwarded only when ``model_type``
                is "ligand_mpnn".

        Raises:
            ValueError: If ``model_type`` is not a key of ``MODEL_CLASSES``.
        """
        super().__init__()

        model_cls = MODEL_CLASSES.get(model_type)
        if model_cls is None:
            raise ValueError(
                f"Unknown model_type '{model_type}'. "
                f"Choose from: {list(MODEL_CLASSES.keys())}"
            )

        common_kwargs = dict(
            num_node_features=hidden_dim,
            num_edge_features=hidden_dim,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
        )
        # The ligand-context options are only passed to LigandMPNN; for
        # ProteinMPNN they are recorded in the config but unused.
        if model_type == "ligand_mpnn":
            common_kwargs["num_context_atoms"] = num_context_atoms
            common_kwargs["num_context_encoding_layers"] = num_context_encoding_layers

        self.model = model_cls(**common_kwargs)

    def forward(
        self,
        X: torch.Tensor,
        S: Optional[torch.Tensor] = None,
        residue_mask: Optional[torch.Tensor] = None,
        designed_residue_mask: Optional[torch.Tensor] = None,
        chain_labels: Optional[torch.Tensor] = None,
        R_idx: Optional[torch.Tensor] = None,
        temperature: float = 0.1,
        **kwargs,
    ) -> MPNNModelOutput:
        """
        Run ProteinMPNN / LigandMPNN sequence design.

        Args:
            X: Backbone atom coordinates [B, L, num_atoms, 3].
                For ProteinMPNN: num_atoms=4 (N, CA, C, O).
            S: Ground-truth sequence tokens [B, L] (optional, for teacher forcing).
                Defaults to all zeros (long).
            residue_mask: Valid residue mask [B, L] (default: all valid).
            designed_residue_mask: Which residues to design [B, L] (default: all).
            chain_labels: Chain identifiers [B, L] (default: single chain, all 0).
            R_idx: Residue indices [B, L] (default: 0..L-1 for every batch item).
            temperature: Sampling temperature (default: 0.1).
            **kwargs: Extra entries merged into the foundry network-input dict
                last, so they override computed entries such as ``X_m``.

        Returns:
            MPNNModelOutput with sequence logits and sampled indices. If the
            foundry decoder does not supply ``S_sampled``, the indices are the
            per-residue argmax of the logits.

        Raises:
            ValueError: If ``X`` is not a 4-D tensor whose last dimension is 3.
        """
        # Fail fast with a readable message instead of an obscure error from
        # deep inside the foundry model.
        if X.ndim != 4 or X.shape[-1] != 3:
            raise ValueError(
                "Expected X to have shape [B, L, num_atoms, 3], "
                f"got {tuple(X.shape)}"
            )

        B, L = X.shape[0], X.shape[1]
        device = X.device

        if S is None:
            S = torch.zeros(B, L, dtype=torch.long, device=device)
        if residue_mask is None:
            residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if designed_residue_mask is None:
            designed_residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if chain_labels is None:
            chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
        if R_idx is None:
            R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)

        # Atom mask [B, L, num_atoms]: an atom counts as present unless all of
        # its coordinates are exactly zero. Callers may supply their own
        # ``X_m`` through kwargs, which takes precedence below.
        X_m = (X.abs().sum(dim=-1) > 0).float()

        network_input = {
            "X": X,
            "X_m": X_m,
            "S": S,
            "R_idx": R_idx,
            "chain_labels": chain_labels,
            "residue_mask": residue_mask,
            "designed_residue_mask": designed_residue_mask,
            "temperature": temperature,
            **kwargs,
        }

        output = self.model(network_input)
        decoder_features = output["decoder_features"]
        logits = decoder_features["logits"]  # [B, L, n_vocab]

        # Compute the greedy fallback only when the decoder did not provide
        # sampled indices, rather than paying for an argmax on every call.
        S_sampled = decoder_features.get("S_sampled")
        if S_sampled is None:
            S_sampled = logits.argmax(dim=-1)

        return MPNNModelOutput(
            sequence_logits=logits,
            sequence_indices=S_sampled,
            decoder_features=decoder_features,
        )