# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ProteinMPNN / LigandMPNN model wrapper.

A thin diffusers-compatible wrapper around the foundry MPNN model,
following the same pattern as the transformer and scheduler wrappers.
Reuses the foundry model implementation directly, adding only the
ModelMixin/ConfigMixin interface for diffusers integration.
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from mpnn.model.mpnn import LigandMPNN, ProteinMPNN

# Maps the `model_type` config string to the foundry class that implements it.
MODEL_CLASSES = {
    "protein_mpnn": ProteinMPNN,
    "ligand_mpnn": LigandMPNN,
}


@dataclass
class MPNNModelOutput:
    """Output from the MPNN model wrapper."""

    sequence_logits: torch.Tensor  # [B, L, n_vocab]
    sequence_indices: torch.Tensor  # [B, L]
    decoder_features: dict  # full decoder output dict


class MPNNModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.

    Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
    diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
    to the foundry implementation.

    The wrapped network is stored as `self.model`, so this module's state dict
    keys carry a `model.` prefix relative to the foundry network's own keys.
    NOTE(review): the original documentation states the prefix is stripped on
    load to match the foundry checkpoint format; that logic is not part of
    this class -- confirm in the checkpoint loader.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        model_type: str = "protein_mpnn",
        hidden_dim: int = 128,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        num_neighbors: int = 48,
        dropout_rate: float = 0.1,
        num_positional_embeddings: int = 16,
        min_rbf_mean: float = 2.0,
        max_rbf_mean: float = 22.0,
        num_rbf: int = 16,
        # LigandMPNN-specific
        num_context_atoms: int = 25,
        num_context_encoding_layers: int = 2,
    ) -> None:
        """
        Build the underlying foundry model selected by `model_type`.

        Args:
            model_type: Key into `MODEL_CLASSES` ("protein_mpnn" or
                "ligand_mpnn").
            hidden_dim: Passed to the foundry model as `hidden_dim`,
                `num_node_features` and `num_edge_features`.
            num_encoder_layers: Forwarded unchanged to the foundry model.
            num_decoder_layers: Forwarded unchanged to the foundry model.
            num_neighbors: Forwarded unchanged to the foundry model.
            dropout_rate: Forwarded unchanged to the foundry model.
            num_positional_embeddings: Forwarded unchanged to the foundry
                model.
            min_rbf_mean: Forwarded unchanged to the foundry model.
            max_rbf_mean: Forwarded unchanged to the foundry model.
            num_rbf: Forwarded unchanged to the foundry model.
            num_context_atoms: Forwarded only when `model_type` is
                "ligand_mpnn"; ignored otherwise.
            num_context_encoding_layers: Forwarded only when `model_type` is
                "ligand_mpnn"; ignored otherwise.

        Raises:
            ValueError: If `model_type` is not a key of `MODEL_CLASSES`.
        """
        super().__init__()

        model_cls = MODEL_CLASSES.get(model_type)
        if model_cls is None:
            raise ValueError(
                f"Unknown model_type '{model_type}'. "
                f"Choose from: {list(MODEL_CLASSES.keys())}"
            )

        # Node and edge feature widths are tied to the hidden width.
        model_kwargs = dict(
            num_node_features=hidden_dim,
            num_edge_features=hidden_dim,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
        )
        if model_type == "ligand_mpnn":
            # Only the ligand variant receives the context-atom settings.
            model_kwargs["num_context_atoms"] = num_context_atoms
            model_kwargs["num_context_encoding_layers"] = (
                num_context_encoding_layers
            )

        self.model = model_cls(**model_kwargs)

    def forward(
        self,
        X: torch.Tensor,
        S: Optional[torch.Tensor] = None,
        residue_mask: Optional[torch.Tensor] = None,
        designed_residue_mask: Optional[torch.Tensor] = None,
        chain_labels: Optional[torch.Tensor] = None,
        R_idx: Optional[torch.Tensor] = None,
        temperature: float = 0.1,
        **kwargs,
    ) -> MPNNModelOutput:
        """
        Run ProteinMPNN / LigandMPNN sequence design.

        Missing optional inputs are filled with defaults, everything is packed
        into a single dict, and that dict is handed to the wrapped foundry
        model.

        Args:
            X: Backbone atom coordinates [B, L, num_atoms, 3].
                For ProteinMPNN: num_atoms=4 (N, CA, C, O).
            S: Ground-truth sequence tokens [B, L] (optional, for teacher
                forcing). Defaults to all zeros.
            residue_mask: Valid residue mask [B, L] (default: all valid).
            designed_residue_mask: Which residues to design [B, L]
                (default: all).
            chain_labels: Chain identifiers [B, L] (default: single chain,
                label 0).
            R_idx: Residue indices [B, L] (default: 0..L-1).
            temperature: Sampling temperature (default: 0.1).
            **kwargs: Extra entries merged into the model input dict. They are
                merged last, so they take precedence over computed entries
                with the same key (e.g. `X_m`).

        Returns:
            MPNNModelOutput with sequence logits and sampled indices. If the
            model output has no `S_sampled` entry (or it is None), the
            indices are the argmax of the logits.

        Raises:
            ValueError: If `X` is not a rank-4 tensor whose last dimension
                is 3.
        """
        # Fail early with a readable message instead of an IndexError on
        # `X.shape[1]` or an opaque error inside the foundry model.
        if X.dim() != 4 or X.shape[-1] != 3:
            raise ValueError(
                f"X must have shape [B, L, num_atoms, 3], got {tuple(X.shape)}"
            )

        B, L = X.shape[0], X.shape[1]
        device = X.device

        if S is None:
            S = torch.zeros(B, L, dtype=torch.long, device=device)
        if residue_mask is None:
            residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if designed_residue_mask is None:
            designed_residue_mask = torch.ones(
                B, L, dtype=torch.bool, device=device
            )
        if chain_labels is None:
            chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
        if R_idx is None:
            R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)

        # Atom mask derived from coordinate presence: an atom counts as
        # present when the absolute values of its coordinates sum to more
        # than zero. Atoms at the exact origin, or with NaN coordinates,
        # therefore get mask 0.
        X_m = (X.abs().sum(dim=-1) > 0).float()  # [B, L, num_atoms]

        network_input = {
            "X": X,
            "X_m": X_m,
            "S": S,
            "R_idx": R_idx,
            "chain_labels": chain_labels,
            "residue_mask": residue_mask,
            "designed_residue_mask": designed_residue_mask,
            "temperature": temperature,
            **kwargs,
        }

        output = self.model(network_input)
        decoder_features = output["decoder_features"]
        logits = decoder_features["logits"]  # [B, L, n_vocab]

        # Only fall back to greedy argmax when the model did not sample;
        # this avoids computing the argmax when it would be discarded.
        sampled = decoder_features.get("S_sampled")
        if sampled is None:
            sampled = logits.argmax(dim=-1)

        return MPNNModelOutput(
            sequence_logits=logits,
            sequence_indices=sampled,
            decoder_features=decoder_features,
        )