diff --git a/bindevaluator.py b/bindevaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..a1aee7ee7de6250e10b671a7fedb7b265b02069f --- /dev/null +++ b/bindevaluator.py @@ -0,0 +1,182 @@ +import torch +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from datasets import load_from_disk +from transformers import AutoTokenizer +from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, accuracy_score +from argparse import ArgumentParser +import os +import torch.distributed as dist +import pandas as pd +import pdb + +from modules.bindevaluator_modules import * # Import your model and other necessary classes/functions here + +def parse_motifs(motif: str) -> list: + parts = motif.split(',') + result = [] + + for part in parts: + part = part.strip() + if '-' in part: + start, end = map(int, part.split('-')) + result.extend(range(start, end + 1)) + else: + result.append(int(part)) + + result = [pos-1 for pos in result] + print(f'Target Motifs: {result}') + return torch.tensor(result) + + +class PeptideModel(pl.LightningModule): + def __init__(self, n_layers, d_model, d_hidden, n_head, + d_k, d_v, d_inner, dropout=0.2, + learning_rate=0.00001, max_epochs=15, kl_weight=1): + super(PeptideModel, self).__init__() + + self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D") + # freeze all the esm_model parameters + for param in self.esm_model.parameters(): + param.requires_grad = False + + self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden, + n_head, d_k, d_v, d_inner, dropout=dropout) + + self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v, dropout=dropout) + + self.final_ffn = FFN(d_model, d_inner, dropout=dropout) + + self.output_projection_prot = nn.Linear(d_model, 1) + + self.learning_rate = learning_rate + self.max_epochs = max_epochs + self.kl_weight = kl_weight + + self.classification_threshold = nn.Parameter(torch.tensor(0.5)) # Initial 
threshold + self.historical_memory = 0.9 + self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925]) # binding_site weights, non-bidning site weights + + def forward(self, binder_tokens, target_tokens): + peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state + protein_sequence = self.esm_model(**target_tokens).last_hidden_state + + prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \ + seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence, + protein_sequence) + + prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc) + + prot_enc = self.final_ffn(prot_enc) + + prot_enc = self.output_projection_prot(prot_enc) + + return prot_enc + + +def calculate_score(target_sequence, binder_sequence, model, args): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") + anchor_tokens = tokenizer(target_sequence, return_tensors='pt', padding=True, truncation=True, max_length=40000) + positive_tokens = tokenizer(binder_sequence, return_tensors='pt', padding=True, truncation=True, max_length=40000) + + anchor_tokens['attention_mask'][0][0] = 0 + anchor_tokens['attention_mask'][0][-1] = 0 + positive_tokens['attention_mask'][0][0] = 0 + positive_tokens['attention_mask'][0][-1] = 0 + + target_tokens = {'input_ids': anchor_tokens["input_ids"].to(device), + 'attention_mask': anchor_tokens["attention_mask"].to(device)} + binder_tokens = {'input_ids': positive_tokens['input_ids'].to(device), + 'attention_mask': positive_tokens['attention_mask'].to(device)} + + model.eval() + + # pdb.set_trace() + + prediction = model(binder_tokens, target_tokens).squeeze(-1)[0][1:-1] + prediction = torch.sigmoid(prediction) + + return prediction, model.classification_threshold + + +def compute_metrics(true_residues, predicted_residues, length): + # Initialize the true and predicted 
lists with 0 + true_list = [0] * length + predicted_list = [0] * length + + # Set the values to 1 based on the provided lists + for index in true_residues: + true_list[index] = 1 + for index in predicted_residues: + predicted_list[index] = 1 + + # Compute the metrics + accuracy = accuracy_score(true_list, predicted_list) + f1 = f1_score(true_list, predicted_list) + mcc = matthews_corrcoef(true_list, predicted_list) + + return accuracy, f1, mcc + + +def main(): + parser = ArgumentParser() + parser.add_argument("-sm", default='/home/tc415/muPPIt/muppit/train_base_1/model-epoch=14-val_loss=0.40.ckpt', + help="File containing initial params", type=str) + parser.add_argument("-batch_size", type=int, default=32, help="Batch size") + parser.add_argument("-lr", type=float, default=1e-3) + parser.add_argument("-n_layers", type=int, default=6, help="Number of layers") + parser.add_argument("-d_model", type=int, default=64, help="Dimension of model") + parser.add_argument("-d_hidden", type=int, default=128, help="Dimension of CNN block") + parser.add_argument("-n_head", type=int, default=6, help="Number of heads") + parser.add_argument("-d_inner", type=int, default=64) + parser.add_argument("-target", type=str) + parser.add_argument("-binder", type=str) + parser.add_argument("-gt", type=str, default=None) + parser.add_argument("-motifs", type=str, default=None) + args = parser.parse_args() + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + model = PeptideModel.load_from_checkpoint(args.sm, + n_layers=args.n_layers, + d_model=args.d_model, + d_hidden=args.d_hidden, + n_head=args.n_head, + d_k=64, + d_v=128, + d_inner=64).to(device) + + prediction, _ = calculate_score(args.target, args.binder, model, args) + # print(prediction) + # print(model.classification_threshold) + + binding_site = [] + for i in range(len(prediction)): + if prediction[i] >= 0.5: + binding_site.append(i) + + print("Prediction: ", binding_site) + prediction = 
prediction.detach().cpu().tolist() + np.set_printoptions(precision=2, suppress=True) + print(prediction) + + if args.motifs is not None: + motifs = parse_motifs(args.motifs).tolist() + print(f"Motif Score: {torch.sum(prediction[motifs]) / len(motifs)}") + + if args.gt is not None: + L = len(args.target) + # print(L) + gt = parse_motifs(args.gt) + print("Ground Truth: ", gt) + + acc, f1, mcc = compute_metrics(gt, binding_site, L) + print(f"Accuracy={acc}\tF1={f1}\tMCC={mcc}") + + # print("Prediction Logits: ", prediction[binding_site]) + + +if __name__ == "__main__": + main() diff --git a/classifier_code/__init__.py b/classifier_code/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/classifier_code/binding_affinity_unpooled.py b/classifier_code/binding_affinity_unpooled.py new file mode 100644 index 0000000000000000000000000000000000000000..b67747b327754dbe6fef818f37c4193599bc61f0 --- /dev/null +++ b/classifier_code/binding_affinity_unpooled.py @@ -0,0 +1,321 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, f1_score +from scipy.stats import spearmanr +from collections import defaultdict +import pandas as pd +import logging +import os +import torch.optim as optim +from datetime import datetime +from transformers import AutoModel, AutoConfig, AutoTokenizer +class UnpooledBindingPredictor(nn.Module): + def __init__(self, + esm_model_name="facebook/esm2_t33_650M_UR50D", + hidden_dim=512, + kernel_sizes=[3, 5, 7], + n_heads=8, + n_layers=3, + dropout=0.1, + freeze_esm=True): + super().__init__() + + # Define binding thresholds + self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM + self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM + + # Load ESM model for computing embeddings on the fly + self.esm_model = 
AutoModel.from_pretrained(esm_model_name) + self.config = AutoConfig.from_pretrained(esm_model_name) + + # Freeze ESM parameters if needed + if freeze_esm: + for param in self.esm_model.parameters(): + param.requires_grad = False + + # Get ESM hidden size + esm_dim = self.config.hidden_size + + # Output channels for CNN layers + output_channels_per_kernel = 64 + + # CNN layers for handling variable length sequences + self.protein_conv_layers = nn.ModuleList([ + nn.Conv1d( + in_channels=esm_dim, + out_channels=output_channels_per_kernel, + kernel_size=k, + padding='same' + ) for k in kernel_sizes + ]) + + self.binder_conv_layers = nn.ModuleList([ + nn.Conv1d( + in_channels=esm_dim, + out_channels=output_channels_per_kernel, + kernel_size=k, + padding='same' + ) for k in kernel_sizes + ]) + + # Calculate total features after convolution and pooling + total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2 + + # Project to same dimension after CNN processing + self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim) + self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim) + + self.protein_norm = nn.LayerNorm(hidden_dim) + self.binder_norm = nn.LayerNorm(hidden_dim) + + # Cross attention blocks with layer norm + self.cross_attention_layers = nn.ModuleList([ + nn.ModuleDict({ + 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout), + 'norm1': nn.LayerNorm(hidden_dim), + 'ffn': nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim) + ), + 'norm2': nn.LayerNorm(hidden_dim) + }) for _ in range(n_layers) + ]) + + # Prediction heads + self.shared_head = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + ) + + # Regression head + self.regression_head = nn.Linear(hidden_dim, 1) + + # Classification head (3 classes: tight, medium, loose binding) + self.classification_head = 
nn.Linear(hidden_dim, 3) + + def get_binding_class(self, affinity): + """Convert affinity values to class indices + 0: tight binding (>= 7.5) + 1: medium binding (6.0-7.5) + 2: weak binding (< 6.0) + """ + if isinstance(affinity, torch.Tensor): + tight_mask = affinity >= self.tight_threshold + weak_mask = affinity < self.weak_threshold + medium_mask = ~(tight_mask | weak_mask) + + classes = torch.zeros_like(affinity, dtype=torch.long) + classes[medium_mask] = 1 + classes[weak_mask] = 2 + return classes + else: + if affinity >= self.tight_threshold: + return 0 # tight binding + elif affinity < self.weak_threshold: + return 2 # weak binding + else: + return 1 # medium binding + + def compute_embeddings(self, input_ids, attention_mask=None): + """Compute ESM embeddings on the fly""" + esm_outputs = self.esm_model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True + ) + + # Get the unpooled last hidden states (batch_size x seq_length x hidden_size) + return esm_outputs.last_hidden_state + + def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None): + """Process a sequence through CNN layers and pooling""" + # Transpose for CNN: [batch_size, hidden_size, seq_length] + x = unpooled_emb.transpose(1, 2) + + # Apply CNN layers and collect outputs + conv_outputs = [] + for conv in conv_layers: + conv_out = F.relu(conv(x)) + conv_outputs.append(conv_out) + + # Concatenate along channel dimension + conv_output = torch.cat(conv_outputs, dim=1) + + # Global pooling (both max and average) + # If attention mask is provided, use it to create a proper mask for pooling + if attention_mask is not None: + # Create a mask for pooling (1 for valid positions, 0 for padding) + # Expand mask to match conv_output channels + expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1) + + # Apply mask (set padding to large negative value for max pooling) + masked_output = conv_output.clone() + masked_output = 
masked_output.masked_fill(expanded_mask == 0, float('-inf')) + + # Max pooling along sequence dimension + max_pooled = torch.max(masked_output, dim=2)[0] + + # Average pooling (sum divided by number of valid positions) + sum_pooled = torch.sum(conv_output * expanded_mask, dim=2) + valid_positions = torch.sum(expanded_mask, dim=2) + valid_positions = torch.clamp(valid_positions, min=1.0) # Avoid division by zero + avg_pooled = sum_pooled / valid_positions + else: + # If no mask, use standard pooling + max_pooled = torch.max(conv_output, dim=2)[0] + avg_pooled = torch.mean(conv_output, dim=2) + + # Concatenate the pooled features + pooled = torch.cat([max_pooled, avg_pooled], dim=1) + + return pooled + + def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None): + # Compute embeddings on the fly using the ESM model + protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask) + binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask) + + # Process protein and binder sequences through CNN layers + protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask) + binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask) + + # Project to same dimension + protein = self.protein_norm(self.protein_projection(protein_features)) + binder = self.binder_norm(self.binder_projection(binder_features)) + + # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim] + protein = protein.unsqueeze(0) + binder = binder.unsqueeze(0) + + # Cross attention layers + for layer in self.cross_attention_layers: + # Protein attending to binder + attended_protein = layer['attention']( + protein, binder, binder + )[0] + protein = layer['norm1'](protein + attended_protein) + protein = layer['norm2'](protein + layer['ffn'](protein)) + + # Binder attending to protein + attended_binder = layer['attention']( + binder, protein, protein + 
)[0] + binder = layer['norm1'](binder + attended_binder) + binder = layer['norm2'](binder + layer['ffn'](binder)) + + # Remove sequence dimension + protein_pool = protein.squeeze(0) + binder_pool = binder.squeeze(0) + + # Concatenate both representations + combined = torch.cat([protein_pool, binder_pool], dim=-1) + + # Shared features + shared_features = self.shared_head(combined) + + regression_output = self.regression_head(shared_features) + classification_logits = self.classification_head(shared_features) + + return regression_output, classification_logits + +def load_model(checkpoint_path, device): + """Load trained model from checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location=device) + # Import the model class from your module or redefine it here + + # Initialize model with the same parameters used during training + model = UnpooledBindingPredictor( + esm_model_name="facebook/esm2_t33_650M_UR50D", + hidden_dim=384, + kernel_sizes=[3, 5, 7], + n_heads=8, + n_layers=4, + dropout=0.14561457009902096, + freeze_esm=True + ).to(device) + + # Load the trained weights + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() # Set to evaluation mode + + return model + + +def prepare_inputs(protein_sequence, binder_sequence, tokenizer, max_length=1024, device='cuda'): + """Tokenize protein and binder sequences.""" + protein_tokens = tokenizer( + protein_sequence, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ) + + binder_tokens = tokenizer( + binder_sequence, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ) + + return { + 'protein_input_ids': protein_tokens['input_ids'].to(device), + 'protein_attention_mask': protein_tokens['attention_mask'].to(device), + 'binder_input_ids': binder_tokens['input_ids'].to(device), + 'binder_attention_mask': binder_tokens['attention_mask'].to(device) + } + +# Perform prediction +def predict_binding(model, 
protein_sequence, binder_sequence, device='cuda'): + """Predict binding affinity between protein and binder sequences.""" + tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") + inputs = prepare_inputs(protein_sequence, binder_sequence, tokenizer, device=device) + + with torch.no_grad(): + regression_output, classification_logits = model( + inputs['protein_input_ids'], + inputs['binder_input_ids'], + inputs['protein_attention_mask'], + inputs['binder_attention_mask'] + ) + + # Get numerical prediction (pKd/pKi) + predicted_affinity = regression_output.item() + + # Get classification prediction (tight, medium, weak) + predicted_class_idx = torch.argmax(classification_logits, dim=1).item() + class_names = ['Tight binding', 'Medium binding', 'Weak binding'] + predicted_class = class_names[predicted_class_idx] + + # Get class probabilities + class_probs = F.softmax(classification_logits, dim=1).cpu().numpy()[0] + + return { + 'predicted_affinity': predicted_affinity, + 'binding_class': predicted_class, + 'class_probabilities': {name: prob for name, prob in zip(class_names, class_probs)}, + 'tight_threshold': model.tight_threshold, # 7.5 (≤ ~30nM) + 'weak_threshold': model.weak_threshold # 6.0 (> 1μM) + } + +# Example usage +if __name__ == "__main__": + # Set device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Load the model + model = load_model('../classifier_ckpt/binding_affinity_unpooled.pt', device) + + protein_sequence = "GSHMIEPNVISVRLFKRKVGGLGFLVKERVSKPPVIISDLIRGGAAEQSGLIQAGDIILAVNDRPLVDLSYDSALEVLRGIASETHVVLILRGPEGFTTHLETTFTGDGTPKTIRVTQPLGPPTKAV" + binder_sequence = "VVKVDSV" + + result = predict_binding(model, protein_sequence, binder, device) + print(f"Affinity Score: {result['predicted_affinity']}") diff --git a/classifier_code/binding_affinity_unpooled_2.py b/classifier_code/binding_affinity_unpooled_2.py new file mode 100644 index 
0000000000000000000000000000000000000000..bf43e8cf6a538f58258c5f53030f2c111942f6ad --- /dev/null +++ b/classifier_code/binding_affinity_unpooled_2.py @@ -0,0 +1,356 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, f1_score +from scipy.stats import spearmanr +from collections import defaultdict +import pandas as pd +import logging +import os +import torch.optim as optim +from datetime import datetime +from transformers import AutoModel, AutoConfig, AutoTokenizer + +import os + +# point HF_ENDPOINT at your mirror +# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" + +class UnpooledBindingPredictor(nn.Module): + def __init__(self, + esm_model_name="facebook/esm2_t33_650M_UR50D", + hidden_dim=512, + kernel_sizes=[3, 5, 7], + n_heads=8, + n_layers=3, + dropout=0.1, + freeze_esm=True): + super().__init__() + + # Define binding thresholds + self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM + self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM + + # Load ESM model for computing embeddings on the fly + self.esm_model = AutoModel.from_pretrained(esm_model_name) + self.config = AutoConfig.from_pretrained(esm_model_name) + + # Freeze ESM parameters if needed + if freeze_esm: + for param in self.esm_model.parameters(): + param.requires_grad = False + + # Get ESM hidden size + esm_dim = self.config.hidden_size + + # Output channels for CNN layers + output_channels_per_kernel = 64 + + # CNN layers for handling variable length sequences + self.protein_conv_layers = nn.ModuleList([ + nn.Conv1d( + in_channels=esm_dim, + out_channels=output_channels_per_kernel, + kernel_size=k, + padding='same' + ) for k in kernel_sizes + ]) + + self.binder_conv_layers = nn.ModuleList([ + nn.Conv1d( + in_channels=esm_dim, + out_channels=output_channels_per_kernel, + kernel_size=k, + padding='same' + ) for k in 
kernel_sizes + ]) + + # Calculate total features after convolution and pooling + total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2 + + # Project to same dimension after CNN processing + self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim) + self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim) + + self.protein_norm = nn.LayerNorm(hidden_dim) + self.binder_norm = nn.LayerNorm(hidden_dim) + + # Cross attention blocks with layer norm + self.cross_attention_layers = nn.ModuleList([ + nn.ModuleDict({ + 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout), + 'norm1': nn.LayerNorm(hidden_dim), + 'ffn': nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim) + ), + 'norm2': nn.LayerNorm(hidden_dim) + }) for _ in range(n_layers) + ]) + + # Prediction heads + self.shared_head = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + ) + + # Regression head + self.regression_head = nn.Linear(hidden_dim, 1) + + # Classification head (3 classes: tight, medium, loose binding) + self.classification_head = nn.Linear(hidden_dim, 3) + + def get_binding_class(self, affinity): + """Convert affinity values to class indices + 0: tight binding (>= 7.5) + 1: medium binding (6.0-7.5) + 2: weak binding (< 6.0) + """ + if isinstance(affinity, torch.Tensor): + tight_mask = affinity >= self.tight_threshold + weak_mask = affinity < self.weak_threshold + medium_mask = ~(tight_mask | weak_mask) + + classes = torch.zeros_like(affinity, dtype=torch.long) + classes[medium_mask] = 1 + classes[weak_mask] = 2 + return classes + else: + if affinity >= self.tight_threshold: + return 0 # tight binding + elif affinity < self.weak_threshold: + return 2 # weak binding + else: + return 1 # medium binding + + def compute_embeddings(self, input_ids, attention_mask=None): + """Compute ESM embeddings on the fly""" + 
esm_outputs = self.esm_model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True + ) + + # Get the unpooled last hidden states (batch_size x seq_length x hidden_size) + return esm_outputs.last_hidden_state + + def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None): + """Process a sequence through CNN layers and pooling""" + # Transpose for CNN: [batch_size, hidden_size, seq_length] + x = unpooled_emb.transpose(1, 2) + + # Apply CNN layers and collect outputs + conv_outputs = [] + for conv in conv_layers: + conv_out = F.relu(conv(x)) + conv_outputs.append(conv_out) + + # Concatenate along channel dimension + conv_output = torch.cat(conv_outputs, dim=1) + + # Global pooling (both max and average) + # If attention mask is provided, use it to create a proper mask for pooling + if attention_mask is not None: + # Create a mask for pooling (1 for valid positions, 0 for padding) + # Expand mask to match conv_output channels + expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1) + + # Apply mask (set padding to large negative value for max pooling) + masked_output = conv_output.clone() + masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf')) + + # Max pooling along sequence dimension + max_pooled = torch.max(masked_output, dim=2)[0] + + # Average pooling (sum divided by number of valid positions) + sum_pooled = torch.sum(conv_output * expanded_mask, dim=2) + valid_positions = torch.sum(expanded_mask, dim=2) + valid_positions = torch.clamp(valid_positions, min=1.0) # Avoid division by zero + avg_pooled = sum_pooled / valid_positions + else: + # If no mask, use standard pooling + max_pooled = torch.max(conv_output, dim=2)[0] + avg_pooled = torch.mean(conv_output, dim=2) + + # Concatenate the pooled features + pooled = torch.cat([max_pooled, avg_pooled], dim=1) + + return pooled + + def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None): + # Compute 
embeddings on the fly using the ESM model + protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask) + binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask) + + # Process protein and binder sequences through CNN layers + protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask) + binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask) + + # Project to same dimension + protein = self.protein_norm(self.protein_projection(protein_features)) + binder = self.binder_norm(self.binder_projection(binder_features)) + + # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim] + protein = protein.unsqueeze(0) + binder = binder.unsqueeze(0) + + # Cross attention layers + for layer in self.cross_attention_layers: + # Protein attending to binder + attended_protein = layer['attention']( + protein, binder, binder + )[0] + protein = layer['norm1'](protein + attended_protein) + protein = layer['norm2'](protein + layer['ffn'](protein)) + + # Binder attending to protein + attended_binder = layer['attention']( + binder, protein, protein + )[0] + binder = layer['norm1'](binder + attended_binder) + binder = layer['norm2'](binder + layer['ffn'](binder)) + + # Remove sequence dimension + protein_pool = protein.squeeze(0) + binder_pool = binder.squeeze(0) + + # Concatenate both representations + combined = torch.cat([protein_pool, binder_pool], dim=-1) + + # Shared features + shared_features = self.shared_head(combined) + + regression_output = self.regression_head(shared_features) + classification_logits = self.classification_head(shared_features) + + return regression_output, classification_logits + +def load_model(checkpoint_path, device): + """Load trained model from checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location=device) + # Import the model class from your module or redefine it here + + # Initialize model with the same 
parameters used during training + model = UnpooledBindingPredictor( + esm_model_name="facebook/esm2_t33_650M_UR50D", + hidden_dim=384, + kernel_sizes=[3, 5, 7], + n_heads=8, + n_layers=4, + dropout=0.14561457009902096, + freeze_esm=True + ).to(device) + + # Load the trained weights + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() # Set to evaluation mode + + return model + + +def prepare_inputs(protein_sequence, binder_sequence, tokenizer, max_length=1024, device='cuda'): + """Tokenize protein and binder sequences.""" + protein_tokens = tokenizer( + protein_sequence, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ) + + binder_tokens = tokenizer( + binder_sequence, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ) + + return { + 'protein_input_ids': protein_tokens['input_ids'].to(device), + 'protein_attention_mask': protein_tokens['attention_mask'].to(device), + 'binder_input_ids': binder_tokens['input_ids'].to(device), + 'binder_attention_mask': binder_tokens['attention_mask'].to(device) + } + +# Perform prediction +def predict_binding(model, protein_sequence, binder_sequence, device='cuda'): + """Predict binding affinity between protein and binder sequences.""" + tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") + inputs = prepare_inputs(protein_sequence, binder_sequence, tokenizer, device=device) + + with torch.no_grad(): + regression_output, classification_logits = model( + inputs['protein_input_ids'], + inputs['binder_input_ids'], + inputs['protein_attention_mask'], + inputs['binder_attention_mask'] + ) + + # Get numerical prediction (pKd/pKi) + predicted_affinity = regression_output.item() + + # Get classification prediction (tight, medium, weak) + predicted_class_idx = torch.argmax(classification_logits, dim=1).item() + class_names = ['Tight binding', 'Medium binding', 'Weak binding'] + predicted_class = 
class_names[predicted_class_idx] + + # Get class probabilities + class_probs = F.softmax(classification_logits, dim=1).cpu().numpy()[0] + + return { + 'predicted_affinity': predicted_affinity, + 'binding_class': predicted_class, + 'class_probabilities': {name: prob for name, prob in zip(class_names, class_probs)}, + 'tight_threshold': model.tight_threshold, # 7.5 (≤ ~30nM) + 'weak_threshold': model.weak_threshold # 6.0 (> 1μM) + } + +# Example usage +if __name__ == "__main__": + # Set device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Load the model + model = load_model('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/binding_affinity_unpooled.pt', device) + + # Example protein sequences (replace with actual sequences) + binders = ['GLSKGCFGLKLDRIGSMSGLGC', 'RGLSDGFLKLKMGISGSLGC'] + protein_sequence = "RNLTLAVVLPEHNLSYAWAWPRVGPAVALAVEALGRALPVDLRFVSSELEGACSEYLAPLSAVDLKLYHDPDLLLGPGCVYPAASVARFASHWRLPLLTAGAVASGFSAKNDHYRTLVRTGPSAPKLGEFVVTLHGHFNWTARAALLYLDARTDDRPHYFTIEGVFEALQGSNLSVQHQVYAREPGGPEQATHFIRANGRIVYICGPLEMLHEILLQAQRENLTNGDYVFFYLDVFGESLRAGPTRATGRPWQDNRTREQAQALREAFQTVLVITYREPPNPEYQEFQNRLLIRAREDFGVELGPSLMNLIAGCFYDGILLYAEVLNETIQEGGTREDGLRIVEKMQGRRYHGVTGLVVMDKNNDRETDFVLWAMGDLDSGDFQPAAHYSGAEKQIWWTGRPIPWVKGAPPSDNPPCAFDLDDPSCDKTPLSTLAI" + + # name = "CLIC1_10_moppit" + # print(name) + # with open(f'/home/tc415/flow_matching/samples/unconditional_samples/12.txt', 'r') as f: + # binders = f.readlines() + # binders = [binder.strip() for binder in binders] + # binders = binders[:100] + + # # Make prediction + affinities = [] + for binder in binders: + result = predict_binding(model, protein_sequence, binder, device) + print(result['predicted_affinity']) + affinities.append(result['predicted_affinity']) + + # with open('/home/tc415/flow_matching/scores/affinity/EWSFLI1_12_unconditional.txt', 'w') as f: + # for score in affinities: + # f.write(str(score) + '\n') + + # print(sum(affinities) / len(affinities)) + + # with 
open(f'/home/tc415/flow_matching/scores/affinity/{name}.txt', 'w') as f: + # for score in affinities: + # f.write(str(round(score, 4)) + '\n') + + # Display results + # print(f"Predicted binding affinity (pKd/pKi): {result['predicted_affinity']:.2f}") + # print(f"Binding class: {result['binding_class']}") + # print("Class probabilities:") + # for class_name, prob in result['class_probabilities'].items(): + # print(f" {class_name}: {prob:.2f}") \ No newline at end of file diff --git a/classifier_code/half_life.py b/classifier_code/half_life.py new file mode 100644 index 0000000000000000000000000000000000000000..68715637d6dad5af4c0985a56623c70872332fb8 --- /dev/null +++ b/classifier_code/half_life.py @@ -0,0 +1,65 @@ +import numpy as np +import torch +import xgboost as xgb +from transformers import EsmModel, EsmTokenizer +import torch.nn as nn +import pdb + +class PeptideCNN(nn.Module): + def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate): + super().__init__() + self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1) + self.fc = nn.Linear(hidden_dims[1], output_dim) + self.dropout = nn.Dropout(dropout_rate) + self.predictor = nn.Linear(output_dim, 1) # For regression/classification + + self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device) + self.esm_model.eval() + + def forward(self, input_ids, attention_mask=None, return_features=False): + with torch.no_grad(): + x = self.esm_model(input_ids, attention_mask).last_hidden_state + # pdb.set_trace() + # x shape: (B, L, input_dim) + x = x.permute(0, 2, 1) # Reshape to (B, input_dim, L) for Conv1d + x = nn.functional.relu(self.conv1(x)) + x = self.dropout(x) + x = nn.functional.relu(self.conv2(x)) + x = self.dropout(x) + x = x.permute(0, 2, 1) # Reshape back to (B, L, hidden_dims[1]) + + # Global average pooling over the sequence dimension (L) + x = x.mean(dim=1) # Shape: 
(B, hidden_dims[1]) + + features = self.fc(x) # features shape: (B, output_dim) + if return_features: + return features + return self.predictor(features) # Output shape: (B, 1) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +input_dim = 1280 +hidden_dims = [input_dim // 2, input_dim // 4] +output_dim = input_dim // 8 +dropout_rate = 0.3 + +nn_model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device) +nn_model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth')) +nn_model.eval() + +def predict(inputs): + with torch.no_grad(): + prediction = nn_model(**inputs, return_features=False) + + return prediction.item() + +if __name__ == '__main__': + sequence = 'RGLSDGFLKLKMGISGSLGC' + + tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") + inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device) + + prediction = predict(inputs) + print(prediction) + print(f"Predicted half life of {sequence} is {(10**prediction):.4f} h") diff --git a/classifier_code/hemolysis_wt.py b/classifier_code/hemolysis_wt.py new file mode 100644 index 0000000000000000000000000000000000000000..192f3cb13f0eda08b8373e0942a78e8a764c568a --- /dev/null +++ b/classifier_code/hemolysis_wt.py @@ -0,0 +1,101 @@ +import sys +import os +sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts') +import xgboost as xgb +import torch +import numpy as np +import warnings +import numpy as np +from rdkit import Chem, rdBase, DataStructs +from transformers import AutoTokenizer, EsmModel + +rdBase.DisableLog('rdApp.error') +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +class Hemolysis: + def __init__(self): + # change model path + self.predictor = 
class Hemolysis:
    """Scores peptide sequences for hemolytic potential.

    Pipeline: mean-pooled ESM-2 embeddings -> pretrained XGBoost classifier.
    ``__call__`` returns, per sequence, the probability of being NOT hemolytic.
    """

    def __init__(self):
        # change model path
        self.predictor = xgb.Booster(model_file='/home/tc415/flow_matching/classifier_ckpt/best_model_hemolysis.json')

        # Frozen ESM-2 backbone used purely as a feature extractor.
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model.eval()

    def generate_embeddings(self, sequences):
        """Return one mean-pooled ESM embedding per sequence as a (N, hidden) array."""
        pooled_batches = []
        batch_size = 8  # small batches to bound peak memory

        for start in range(0, len(sequences), batch_size):
            chunk = sequences[start:start + batch_size]

            encoded = self.tokenizer(
                chunk,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )

            if torch.cuda.is_available():
                encoded = {k: v.cuda() for k, v in encoded.items()}
                self.model = self.model.cuda()

            with torch.no_grad():
                outputs = self.model(**encoded)

            hidden = outputs.last_hidden_state

            # Mean over real (non-padding) token positions only.
            mask = encoded['attention_mask'].unsqueeze(-1)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1)
            pooled = summed / counts

            pooled_batches.append(pooled.cpu().numpy())

        if pooled_batches:
            return np.vstack(pooled_batches)
        return np.array([])

    def get_scores(self, input_seqs: list):
        """Return P(not hemolytic) per sequence; all-ones fallback when no embeddings."""
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize before handing to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        probs = self.predictor.predict(xgb.DMatrix(features))
        # return the probability of it being not hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
class Nonfouling:
    """Scores peptide sequences for non-fouling propensity.

    Pipeline: mean-pooled ESM-2 embeddings -> pretrained XGBoost classifier;
    ``__call__`` returns the classifier probability per sequence.
    """

    def __init__(self):
        # change model path
        self.predictor = xgb.Booster(model_file='../classifier_ckpt/best_model_nonfouling.json')

        # Frozen ESM-2 backbone used purely as a feature extractor.
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model.eval()

    def generate_embeddings(self, sequences):
        """Return one mean-pooled ESM embedding per sequence as a (N, hidden) array."""
        pooled_batches = []
        batch_size = 8  # small batches to bound peak memory

        for start in range(0, len(sequences), batch_size):
            chunk = sequences[start:start + batch_size]

            encoded = self.tokenizer(
                chunk,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )

            if torch.cuda.is_available():
                encoded = {k: v.cuda() for k, v in encoded.items()}
                self.model = self.model.cuda()

            with torch.no_grad():
                outputs = self.model(**encoded)

            hidden = outputs.last_hidden_state

            # Mean over real (non-padding) token positions only.
            mask = encoded['attention_mask'].unsqueeze(-1)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1)
            pooled = summed / counts

            pooled_batches.append(pooled.cpu().numpy())

        if pooled_batches:
            return np.vstack(pooled_batches)
        return np.array([])

    def get_scores(self, input_seqs: list):
        """Return the classifier probability per sequence; zeros when no embeddings."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize before handing to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        scores = self.predictor.predict(xgb.DMatrix(features))
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
class Solubility:
    """Scores peptide sequences for solubility.

    Pipeline: mean-pooled ESM-2 embeddings -> pretrained XGBoost classifier;
    ``__call__`` returns the classifier probability per sequence.
    """

    def __init__(self):
        # change model path
        self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')

        # Frozen ESM-2 backbone used purely as a feature extractor.
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model.eval()

    def generate_embeddings(self, sequences):
        """Return one mean-pooled ESM embedding per sequence as a (N, hidden) array."""
        pooled_batches = []
        batch_size = 8  # small batches to bound peak memory

        for start in range(0, len(sequences), batch_size):
            chunk = sequences[start:start + batch_size]

            encoded = self.tokenizer(
                chunk,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )

            if torch.cuda.is_available():
                encoded = {k: v.cuda() for k, v in encoded.items()}
                self.model = self.model.cuda()

            with torch.no_grad():
                outputs = self.model(**encoded)

            hidden = outputs.last_hidden_state

            # Mean over real (non-padding) token positions only.
            mask = encoded['attention_mask'].unsqueeze(-1)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1)
            pooled = summed / counts

            pooled_batches.append(pooled.cpu().numpy())

        if pooled_batches:
            return np.vstack(pooled_batches)
        return np.array([])

    def get_scores(self, input_seqs: list):
        """Return the classifier probability per sequence; zeros when no embeddings."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize before handing to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        scores = self.predictor.predict(xgb.DMatrix(features))
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
class MixturePathGeneralizedKL(_Loss):
    r"""Generalized KL objective for discrete flow matching.

    Measures the generalized KL of a posterior model :math:`p_{1|t}` against the
    mixture probability path ``path``; the model is assumed to be trained on
    that same path. Per token the loss is

    .. math::
        \ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \biggr[ p_{1|t}(x_t^i|x_t) -\delta_{x^i_1}(x_t^i) + (1-\delta_{x^i_1}(x_t^i))\left(\log p_{1|t}(x_1^i|x_t)\right)\biggr],

    where :math:`\kappa_t` is the scheduler associated with ``path``. Positions
    whose target token is 1 (the mask token) are zeroed out before reduction.

    Args:
        path (MixtureDiscreteProbPath): Probability path (x-prediction training).
        reduction (str, optional): ``'none'`` | ``'mean'`` | ``'sum'``. Defaults to 'mean'.
    """

    def __init__(self, path: MixtureDiscreteProbPath, reduction: str = "mean") -> None:
        super().__init__(None, None, reduction)
        self.path = path

    def forward(self, logits: Tensor, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Evaluate the generalized KL loss.

        Args:
            logits (Tensor): posterior model output (softmax(``logits``) :math:`=p_{1|t}(x|x_t)`), shape (batch, d, K).
            x_1 (Tensor): target data point, shape (batch, d).
            x_t (Tensor): conditional sample at time t, shape (batch, d).
            t (Tensor): times in :math:`[0,1]`, shape (batch).

        Raises:
            ValueError: if ``self.reduction`` is not ``'none'`` | ``'mean'`` | ``'sum'``.

        Returns:
            Tensor: generalized KL loss (scalar for 'mean'/'sum', per-position for 'none').
        """
        target_shape = x_1.shape

        # log p_{1|t}(x|x_t), and its value at the target token x_1.
        log_posterior = torch.log_softmax(logits, dim=-1)
        log_prob_at_x1 = torch.gather(
            log_posterior, dim=-1, index=x_1.unsqueeze(-1)
        ).view(*target_shape)

        # p_{1|t}(x|x_t) evaluated at the current state x_t.
        prob_at_xt = torch.gather(
            torch.exp(log_posterior), dim=-1, index=x_t.unsqueeze(-1)
        ).view(*target_shape)

        sched = self.path.scheduler(t)

        # Jump rate d_kappa / (1 - kappa), broadcast to the shape of x_1.
        rate = (sched.d_alpha_t / (1 - sched.alpha_t))[
            (...,) + (None,) * (x_1.dim() - 1)
        ]
        rate = rate.repeat(1, *target_shape[1:])

        hit = (x_t == x_1).to(log_posterior.dtype)  # delta_{x_1}(x_t)

        loss = -rate * (prob_at_xt - hit + (1 - hit) * log_prob_at_x1)

        # Exclude positions whose target is the mask token (id 1).
        # NOTE(review): 'mean' still averages over the masked positions too —
        # confirm that is the intended normalization.
        loss = loss * (x_1 != 1).to(loss.dtype)

        if self.reduction == "mean":
            return torch.mean(loss)
        if self.reduction == "sum":
            return torch.sum(loss)
        if self.reduction == "none":
            return loss
        raise ValueError(f"{self.reduction} is not a valid value for reduction")
class AffineProbPath(ProbPath):
    r"""Probability path defined by an affine interpolation

    .. math::

        X_t = \alpha_t X_1 + \sigma_t X_0,

    where :math:`X_0`/:math:`X_1` are source/target points and the scheduler
    supplies the time-dependent coefficients :math:`\alpha_t`, :math:`\sigma_t`
    together with their derivatives. Besides sampling, the class converts
    between the three equivalent training parameterizations — velocity
    :math:`\dot{X}_t`, target :math:`X_1`, and noise :math:`\epsilon`.

    Typical training loop::

        my_path = AffineProbPath(...)
        mse_loss = torch.nn.MSELoss()
        for x_1 in dataset:
            x_0 = torch.randn()          # source noise
            t = torch.rand()             # random time in [0,1]
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
            loss = mse_loss(path_sample.dx_t, my_model(x_t, t))
            loss.backward()

    Args:
        scheduler (Scheduler): provides :math:`\alpha_t`, :math:`\sigma_t` and their derivatives.
    """

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample :math:`X_t = \alpha_t X_1 + \sigma_t X_0` and the conditional
        velocity :math:`\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: a conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        so = self.scheduler(t)

        # Broadcast each per-batch coefficient up to the data shape.
        alpha_t = expand_tensor_like(input_tensor=so.alpha_t, expand_to=x_1)
        sigma_t = expand_tensor_like(input_tensor=so.sigma_t, expand_to=x_1)
        d_alpha_t = expand_tensor_like(input_tensor=so.d_alpha_t, expand_to=x_1)
        d_sigma_t = expand_tensor_like(input_tensor=so.d_sigma_t, expand_to=x_1)

        # x_t ~ p_t(x|x_1) and its time derivative.
        x_t = sigma_t * x_0 + alpha_t * x_1
        dx_t = d_sigma_t * x_0 + d_alpha_t * x_1

        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)

    def target_to_velocity(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Given the target :math:`X_1`, return the velocity :math:`\dot{X}_t`.

        Args:
            x_1 (Tensor): target data point.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        so = self.scheduler(t)

        a_t = so.d_sigma_t / so.sigma_t
        b_t = (so.d_alpha_t * so.sigma_t - so.d_sigma_t * so.alpha_t) / so.sigma_t

        return a_t * x_t + b_t * x_1

    def epsilon_to_velocity(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Given the noise :math:`\epsilon`, return the velocity :math:`\dot{X}_t`.

        Args:
            epsilon (Tensor): noise in the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        so = self.scheduler(t)

        a_t = so.d_alpha_t / so.alpha_t
        b_t = (so.d_sigma_t * so.alpha_t - so.d_alpha_t * so.sigma_t) / so.alpha_t

        return a_t * x_t + b_t * epsilon

    def velocity_to_target(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Given the velocity :math:`\dot{X}_t`, return the target :math:`X_1`.

        Args:
            velocity (Tensor): velocity at the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: target data point.
        """
        so = self.scheduler(t)

        # Inverse of target_to_velocity; shares the Wronskian-style denominator.
        a_t = -so.d_sigma_t / (so.d_alpha_t * so.sigma_t - so.d_sigma_t * so.alpha_t)
        b_t = so.sigma_t / (so.d_alpha_t * so.sigma_t - so.d_sigma_t * so.alpha_t)

        return a_t * x_t + b_t * velocity

    def epsilon_to_target(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Given the noise :math:`\epsilon`, return the target :math:`X_1`.

        Args:
            epsilon (Tensor): noise in the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: target data point.
        """
        so = self.scheduler(t)

        # Solve x_t = alpha_t x_1 + sigma_t eps for x_1.
        a_t = 1 / so.alpha_t
        b_t = -so.sigma_t / so.alpha_t

        return a_t * x_t + b_t * epsilon

    def velocity_to_epsilon(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Given the velocity :math:`\dot{X}_t`, return the noise :math:`\epsilon`.

        Args:
            velocity (Tensor): velocity at the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: noise in the path sample.
        """
        so = self.scheduler(t)

        a_t = -so.d_alpha_t / (so.d_sigma_t * so.alpha_t - so.d_alpha_t * so.sigma_t)
        b_t = so.alpha_t / (so.d_sigma_t * so.alpha_t - so.d_alpha_t * so.sigma_t)

        return a_t * x_t + b_t * velocity

    def target_to_epsilon(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Given the target :math:`X_1`, return the noise :math:`\epsilon`.

        Args:
            x_1 (Tensor): target data point.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: noise in the path sample.
        """
        so = self.scheduler(t)

        # Solve x_t = alpha_t x_1 + sigma_t eps for eps.
        a_t = 1 / so.sigma_t
        b_t = -so.alpha_t / so.sigma_t

        return a_t * x_t + b_t * x_1


class CondOTProbPath(AffineProbPath):
    r"""Conditional optimal-transport probability path: the affine path with

    .. math::

        \alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.
    """

    def __init__(self):
        self.scheduler = CondOTScheduler()
class GeodesicProbPath(ProbPath):
    r"""Probability path whose interpolant follows manifold geodesics:

    .. math::

        X_t = \psi_t(X_0 | X_1) = \exp_{X_1}(\kappa_t \log_{X_1}(X_0)),

    where :math:`\kappa_t` is supplied by a differentiable convex scheduler.

    Typical training loop::

        manifold = FlatTorus()
        scheduler = CondOTScheduler()
        my_path = GeodesicProbPath(scheduler, manifold)
        mse_loss = torch.nn.MSELoss()
        for x_1 in dataset:
            x_0 = torch.randn()
            t = torch.rand()
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
            loss = mse_loss(path_sample.dx_t, my_model(x_t, t))
            loss.backward()

    Args:
        scheduler (ConvexScheduler): provides :math:`\kappa_t`; must be differentiable.
        manifold (Manifold): the manifold on which the path is defined.
    """

    def __init__(self, scheduler: ConvexScheduler, manifold: Manifold):
        self.scheduler = scheduler
        self.manifold = manifold

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample :math:`X_t` on the geodesic and its conditional velocity via forward-mode AD.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: A conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)
        # One time value per sample, broadcastable against the data; cloned so
        # jvp's tangent does not alias the caller's tensor.
        t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()

        def flow_with_velocity(src, dst, tau):
            # Differentiate the geodesic interpolant w.r.t. time with a JVP.
            interpolant = geodesic(self.manifold, src, dst)
            value, tangent = jvp(
                lambda time: interpolant(self.scheduler(time).alpha_t),
                (tau,),
                (torch.ones_like(tau).to(tau),),
            )
            return value, tangent

        x_t, dx_t = vmap(flow_with_velocity)(x_0, x_1, t)
        x_t = x_t.reshape_as(x_1)
        dx_t = dx_t.reshape_as(x_1)

        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)
class MixtureDiscreteProbPath(ProbPath):
    r"""Factorized discrete probability path.

    Each coordinate stays at the source value :math:`X_0` until a random,
    scheduler-determined time, then flips to the target :math:`X_1`:

    .. math::

        P(X_t = X_0) = \sigma_t \quad \text{and} \quad P(X_t = X_1) = 1 - \sigma_t,

    where :math:`\sigma_t` is provided by the scheduler.

    Example::

        >>> x_0 = torch.zeros((1, 3, 3))
        >>> x_1 = torch.ones((1, 3, 3))
        >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
        >>> path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t  # mostly x_0
        >>> path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t  # all x_1

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`.
    """

    def __init__(self, scheduler: ConvexScheduler):
        assert isinstance(
            scheduler, ConvexScheduler
        ), "Scheduler for ConvexProbPath must be a ConvexScheduler."

        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        r"""Sample :math:`X_t` by flipping each coordinate of :math:`X_0` to
        :math:`X_1` independently with probability :math:`1 - \sigma_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            DiscretePathSample: a conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        keep_prob = expand_tensor_like(
            input_tensor=self.scheduler(t).sigma_t, expand_to=x_1
        )

        # Coordinates below the keep probability stay at the source value.
        stay_at_source = torch.rand(size=x_1.shape, device=x_1.device) < keep_prob
        x_t = torch.where(condition=stay_at_source, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)

    def posterior_to_velocity(
        self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
    ) -> Tensor:
        r"""Convert the factorized posterior :math:`p(X_1|X_t)` into a velocity :math:`u_t`.

        Args:
            posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size).
            x_t (Tensor): path sample at time t, shape (...).
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        posterior = torch.softmax(posterior_logits, dim=-1)
        state_one_hot = F.one_hot(x_t, num_classes=posterior.shape[-1])
        t = unsqueeze_to_match(source=t, target=state_one_hot)

        sched = self.scheduler(t)

        # u_t = d_kappa / (1 - kappa) * (posterior - onehot(x_t)).
        return (sched.d_alpha_t / (1 - sched.alpha_t)) * (posterior - state_one_hot)
class ProbPath(ABC):
    r"""Abstract class, representing a probability path.

    A probability path transforms the distribution :math:`p(X_0)` into :math:`p(X_1)` over :math:`t=0\rightarrow 1`.

    The ``ProbPath`` class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives.
    Here is a high-level example

    .. code-block:: python

        # Instantiate a probability path
        my_path = ProbPath(...)

        for x_0, x_1 in dataset:
            # Sets t to a random value in [0,1]
            t = torch.rand()

            # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Optimizes the model. The loss function varies, depending on model and path.
            loss(path_sample, my_model(x_t, t)).backward()

    """

    @abstractmethod
    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from an abstract probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)`.
        | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all objects are under ``PathSample``.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: a conditional sample.
        """

    def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
        # Validates the sample() contract: t must be a flat batch vector whose
        # length matches the leading (batch) dimension of both endpoints.
        assert (
            t.ndim == 1
        ), f"The time vector t must have shape [batch_size]. Got {t.shape}."
        assert (
            t.shape[0] == x_0.shape[0] == x_1.shape[0]
        ), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}"
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from torch import Tensor


@dataclass
class PathSample:
    r"""A single draw from a conditional-flow generated probability path.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
        dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...).

    """

    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
    dx_t: Tensor = field(
        metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."}
    )


@dataclass
class DiscretePathSample:
    r"""A single draw from a conditional-flow generated discrete probability path.

    Same fields as :class:`PathSample` but without a continuous conditional
    target ``dx_t`` (discrete paths train on the posterior instead).

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
    """

    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
    )

# ---------------------------------------------------------------------------
# flow_matching/path/scheduler/__init__.py
# ---------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .schedule_transform import ScheduleTransformedModel
from .scheduler import (
    CondOTScheduler,
    ConvexScheduler,
    CosineScheduler,
    LinearVPScheduler,
    PolynomialConvexScheduler,
    Scheduler,
    SchedulerOutput,
    VPScheduler,
)

__all__ = [
    "CondOTScheduler",
    "CosineScheduler",
    "ConvexScheduler",
    "PolynomialConvexScheduler",
    "ScheduleTransformedModel",
    "Scheduler",
    "VPScheduler",
    "LinearVPScheduler",
    "SchedulerOutput",
]

# ---------------------------------------------------------------------------
# flow_matching/path/scheduler/schedule_transform.py
# ---------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from torch import Tensor

from flow_matching.path.scheduler.scheduler import Scheduler
from flow_matching.utils import ModelWrapper


class ScheduleTransformedModel(ModelWrapper):
    """
    Change of scheduler for a velocity model.

    This class wraps a given velocity model and transforms its scheduling
    to a new scheduler function. It modifies the time
    dynamics of the model according to the new scheduler while maintaining
    the original model's behavior.

    Example:

        .. code-block:: python

            import torch
            from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
            from flow_matching.solver import ODESolver

            # Initialize the model and schedulers
            model = ...

            original_scheduler = CondOTScheduler()
            new_scheduler = CosineScheduler()

            # Create the transformed model
            transformed_model = ScheduleTransformedModel(
                velocity_model=model,
                original_scheduler=original_scheduler,
                new_scheduler=new_scheduler
            )

            # Set up the solver
            solver = ODESolver(velocity_model=transformed_model)

            x_0 = torch.randn([10, 2])  # Example initial condition

            x_1 = solver.sample(
                time_steps=torch.tensor([0.0, 1.0]),
                x_init=x_0,
                step_size=1/1000
            )[1]

    Args:
        velocity_model (ModelWrapper): The original velocity model to be transformed.
        original_scheduler (Scheduler): The scheduler used by the original model. Must implement the snr_inverse function.
        new_scheduler (Scheduler): The new scheduler to be applied to the model.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        original_scheduler: Scheduler,
        new_scheduler: Scheduler,
    ):
        super().__init__(model=velocity_model)
        self.original_scheduler = original_scheduler
        self.new_scheduler = new_scheduler

        # `forward` must invert the original scheduler's SNR; fail fast at
        # construction time if that capability is missing.
        assert hasattr(self.original_scheduler, "snr_inverse") and callable(
            getattr(self.original_scheduler, "snr_inverse")
        ), "The original scheduler must have a callable 'snr_inverse' method."

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Compute the transformed marginal velocity field for a new scheduler.
        This method implements a post-training velocity scheduler change for
        affine conditional flows. It transforms a generating marginal velocity
        field :math:`u_t(x)` based on an original scheduler to a new marginal velocity
        field :math:`\bar{u}_r(x)` based on a different scheduler, while maintaining
        the same data coupling.
        The transformation is based on the scale-time (ST) transformation
        between the two conditional flows, defined as:

        .. math::

            \bar{X}_r = s_r X_{t_r},

        where :math:`X_t` and :math:`\bar{X}_r` are defined by their respective schedulers.
        The ST transformation is computed as:

        .. math::

            t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}.

        Here, :math:`\rho(t)` is the signal-to-noise ratio (SNR) defined as:

        .. math::

            \rho(t) = \frac{\alpha_t}{\sigma_t}.

        :math:`\bar{\rho}(r)` is similarly defined for the new scheduler.
        The marginal velocity for the new scheduler is then given by:

        .. math::

            \bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).

        Args:
            x (Tensor): :math:`x_t`, the input tensor.
            t (Tensor): The time tensor (denoted as :math:`r` above).
            **extras: Additional arguments for the model.
        Returns:
            Tensor: The transformed velocity.
        """
        r = t

        r_scheduler_output = self.new_scheduler(t=r)

        alpha_r = r_scheduler_output.alpha_t
        sigma_r = r_scheduler_output.sigma_t
        d_alpha_r = r_scheduler_output.d_alpha_t
        d_sigma_r = r_scheduler_output.d_sigma_t

        # t_r = rho^{-1}(rho_bar(r)): the original-schedule time with the same
        # signal-to-noise ratio as the new schedule at r.
        t = self.original_scheduler.snr_inverse(alpha_r / sigma_r)

        t_scheduler_output = self.original_scheduler(t=t)

        alpha_t = t_scheduler_output.alpha_t
        sigma_t = t_scheduler_output.sigma_t
        d_alpha_t = t_scheduler_output.d_alpha_t
        d_sigma_t = t_scheduler_output.d_sigma_t

        # Scale part of the ST transform: s_r = sigma_bar_r / sigma_{t_r}.
        s_r = sigma_r / sigma_t

        # dt_r: derivative of t_r w.r.t. r, obtained by differentiating the
        # SNR-matching condition above.
        dt_r = (
            sigma_t
            * sigma_t
            * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)
            / (sigma_r * sigma_r * (sigma_t * d_alpha_t - alpha_t * d_sigma_t))
        )

        # ds_r: derivative of s_r w.r.t. r (quotient rule, with chain rule on t_r).
        ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / (sigma_t * sigma_t)

        # u_bar_r(x) = (ds_r / s_r) x + s_r dt_r u_{t_r}(x / s_r)
        u_t = self.model(x=x / s_r, t=t, **extras)
        u_r = ds_r * x / s_r + dt_r * s_r * u_t

        return u_r

# ---------------------------------------------------------------------------
# flow_matching/path/scheduler/scheduler.py
# ---------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Union

import torch
from torch import Tensor


@dataclass
class SchedulerOutput:
    r"""Represents a sample of a conditional-flow generated probability path.

    Attributes:
        alpha_t (Tensor): :math:`\alpha_t`, shape (...).
        sigma_t (Tensor): :math:`\sigma_t`, shape (...).
        d_alpha_t (Tensor): :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
        d_sigma_t (Tensor): :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).

    """

    alpha_t: Tensor = field(metadata={"help": "alpha_t"})
    sigma_t: Tensor = field(metadata={"help": "sigma_t"})
    d_alpha_t: Tensor = field(metadata={"help": "Derivative of alpha_t."})
    d_sigma_t: Tensor = field(metadata={"help": "Derivative of sigma_t."})


class Scheduler(ABC):
    """Base Scheduler class."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""
        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...


class ConvexScheduler(Scheduler):
    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        # Raw docstring: the original non-raw string made "\alpha" etc. invalid
        # (or silently corrupting) escape sequences.
        r"""Scheduler for convex paths.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from :math:`\kappa_t`.

        Args:
            kappa (Tensor): :math:`\kappa`, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        # For convex paths snr = kappa / (1 - kappa), hence kappa = snr / (1 + snr).
        kappa_t = snr / (1.0 + snr)

        return self.kappa_inverse(kappa=kappa_t)


class CondOTScheduler(ConvexScheduler):
    """CondOT Scheduler: alpha_t = t, sigma_t = 1 - t."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-torch.ones_like(t),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        # kappa_t = t for CondOT, so the inverse is the identity.
        return kappa


class PolynomialConvexScheduler(ConvexScheduler):
    """Polynomial Scheduler: alpha_t = t**n, sigma_t = 1 - t**n."""

    def __init__(self, n: Union[float, int]) -> None:
        assert isinstance(
            n, (float, int)
        ), f"`n` must be a float or int. Got {type(n)=}."
        assert n > 0, f"`n` must be positive. Got {n=}."

        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t**self.n,
            sigma_t=1 - t**self.n,
            d_alpha_t=self.n * (t ** (self.n - 1)),
            d_sigma_t=-self.n * (t ** (self.n - 1)),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        return torch.pow(kappa, 1.0 / self.n)


class VPScheduler(Scheduler):
    """Variance Preserving Scheduler (diffusion-style beta schedule)."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max
        super().__init__()

    def __call__(self, t: Tensor) -> SchedulerOutput:
        b = self.beta_min
        B = self.beta_max
        # T is the integrated noise schedule; note time is reversed (1 - t) so
        # that t=1 is the clean data end.
        T = 0.5 * (1 - t) ** 2 * (B - b) + (1 - t) * b
        dT = -(1 - t) * (B - b) - b

        return SchedulerOutput(
            alpha_t=torch.exp(-0.5 * T),
            sigma_t=torch.sqrt(1 - torch.exp(-T)),
            d_alpha_t=-0.5 * dT * torch.exp(-0.5 * T),
            d_sigma_t=0.5 * dT * torch.exp(-T) / torch.sqrt(1 - torch.exp(-T)),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        # Invert snr = alpha/sigma: T = -log(snr^2 / (snr^2 + 1)), then solve
        # the quadratic in (1 - t).
        T = -torch.log(snr**2 / (snr**2 + 1))
        b = self.beta_min
        B = self.beta_max
        t = 1 - ((-b + torch.sqrt(b**2 + 2 * (B - b) * T)) / (B - b))
        return t


class LinearVPScheduler(Scheduler):
    """Linear Variance Preserving Scheduler: alpha_t = t, sigma_t = sqrt(1 - t^2)."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=(1 - t**2) ** 0.5,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-t / (1 - t**2) ** 0.5,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        # snr = t / sqrt(1 - t^2)  =>  t = sqrt(snr^2 / (1 + snr^2)).
        return torch.sqrt(snr**2 / (1 + snr**2))


class CosineScheduler(Scheduler):
    """Cosine Scheduler: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2)."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        pi = torch.pi
        return SchedulerOutput(
            alpha_t=torch.sin(pi / 2 * t),
            sigma_t=torch.cos(pi / 2 * t),
            d_alpha_t=pi / 2 * torch.cos(pi / 2 * t),
            d_sigma_t=-pi / 2 * torch.sin(pi / 2 * t),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        # snr = tan(pi t / 2)  =>  t = (2 / pi) atan(snr).
        return 2.0 * torch.atan(snr) / torch.pi

# ---------------------------------------------------------------------------
# flow_matching/solver/__init__.py
# ---------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .discrete_solver import MixtureDiscreteEulerSolver
from .ode_solver import ODESolver
from .riemannian_ode_solver import RiemannianODESolver
from .solver import Solver

# NOTE: "ModelWrapper" was removed from __all__ — it is not imported in this
# module, so listing it made `from flow_matching.solver import *` raise
# AttributeError. It is exported from flow_matching.utils instead.
__all__ = [
    "ODESolver",
    "Solver",
    "MixtureDiscreteEulerSolver",
    "RiemannianODESolver",
]

# ---------------------------------------------------------------------------
# flow_matching/solver/discrete_solver.py
# ---------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
+ +from contextlib import nullcontext +from math import ceil +from typing import Callable, Optional, Union + +import torch +from torch import Tensor +import gc +from torch.nn import functional as F + +from flow_matching.path import MixtureDiscreteProbPath + +from flow_matching.solver.solver import Solver +from flow_matching.utils import categorical, ModelWrapper +from .utils import get_nearest_times +from ..utils.multi_guidance import * + +try: + from tqdm import tqdm + + TQDM_AVAILABLE = True +except ImportError: + TQDM_AVAILABLE = False + + +class MixtureDiscreteEulerSolver(Solver): + r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``. + Given :math:`X_t \sim p_t`, the algorithm of solver step from :math:`t` to :math:`t+h` for the i-th coordinate is: + + .. math:: + + \begin{align*} + & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\ + & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\ + & Z^i_{\text{change}} \sim U[0,1]\\ + & X_{t+h}^i \sim \begin{cases} + \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\ + \delta_{X_t^i}(\cdot) \text{ else } + \end{cases} + \end{align*} + + Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity is of the mixture probability path is: + + .. math:: + + u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right], + + where + + .. math:: + \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right], + + and + + .. math:: + + \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right]. + + The source distribution :math:`p(x^i)` is given by ``p``. 
+ + Args: + model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`), output must be [..., vocabulary_size]. + path (MixtureDiscreteProbPath): Probability path used for x-prediction training. + vocabulary_size (int): size of the discrete vocabulary. + source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when divergence-free term for the probability velocity is non-zero. Defaults to None. + """ + + def __init__( + self, + model: ModelWrapper, + path: MixtureDiscreteProbPath, + vocabulary_size: int, + source_distribution_p: Optional[Tensor] = None, + ): + super().__init__() + self.model = model + self.path = path + self.vocabulary_size = vocabulary_size + + if source_distribution_p is not None: + assert source_distribution_p.shape == torch.Size( + [vocabulary_size] + ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}." + + self.source_distribution_p = source_distribution_p + + @torch.no_grad() + def sample( + self, + x_init: Tensor, + step_size: Optional[float], + div_free: Union[float, Callable[[float], float]] = 0.0, + dtype_categorical: torch.dtype = torch.float32, + time_grid: Tensor = torch.tensor([0.0, 1.0]), + return_intermediates: bool = False, + verbose: bool = False, + **model_extras, + ) -> Tensor: + """ + Sample a sequence of discrete values from the given model. + + .. code-block:: python + + import torch + from flow_matching.utils import ModelWrapper + from flow_matching.solver import MixtureDiscreteEulerSolver + + class DummyModel(ModelWrapper): + def __init__(self): + super().__init__(None) + def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: + return ... 
+ + model = DummyModel() + solver = MixtureDiscreteEulerSolver(model=model) + + x_init = torch.LongTensor([122, 725]) + step_size = 0.001 + time_grid = torch.tensor([0.0, 1.0]) + + result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid) + + Args: + x_init (Tensor): The initial state. + step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid. + div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0. + dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32. + time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]). + return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False. + verbose (bool): Whether to print progress bars. Defaults to False. + **model_extras: Additional input for the model. + + Returns: + Tensor: The sampled sequence of discrete values. + + Raises: + ImportError: To run in verbose mode, tqdm must be installed. + """ + if not div_free == 0.0: + assert ( + self.source_distribution_p is not None + ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity." + + # Initialize the current state `x_t` with the initial state `X_0`. + time_grid = time_grid.to(device=x_init.device) + + if step_size is None: + # If step_size is None then set the t discretization to time_grid. + t_discretization = time_grid + n_steps = len(time_grid) - 1 + else: + # If step_size is float then t discretization is uniform with step size set by step_size. 
+ t_init = time_grid[0].item() + t_final = time_grid[-1].item() + assert ( + t_final - t_init + ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}." + + n_steps = ceil((t_final - t_init) / step_size) + t_discretization = torch.tensor( + [t_init + step_size * i for i in range(n_steps)] + [t_final], + device=x_init.device, + ) + + if return_intermediates: + # get order of intermediate steps: + order = torch.argsort(time_grid) + # Compute intermediate steps to return via nearest points in t_discretization to time_grid. + time_grid = get_nearest_times( + time_grid=time_grid, t_discretization=t_discretization + ) + + x_t = x_init.clone() + steps_counter = 0 + res = [] + + if return_intermediates: + res = [x_init.clone()] + + if verbose: + if not TQDM_AVAILABLE: + raise ImportError( + "tqdm is required for verbose mode. Please install it." + ) + ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}") + else: + ctx = nullcontext() + + with ctx: + for i in range(n_steps): + t = t_discretization[i : i + 1] + h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1] + + # Sample x_1 ~ p_1|t( \cdot |x_t) + p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras) + x_1 = categorical(p_1t.to(dtype=dtype_categorical)) + + # Checks if final step + if i == n_steps - 1: + x_t = x_1 + else: + # Compute u_t(x|x_t,x_1) + scheduler_output = self.path.scheduler(t=t) + + k_t = scheduler_output.alpha_t + d_k_t = scheduler_output.d_alpha_t + + delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to( + k_t.dtype + ) # [B, L, V] + u = d_k_t / (1 - k_t) * delta_1 + + # Add divergence-free part + div_free_t = div_free(t) if callable(div_free) else div_free + + if div_free_t > 0: + p_0 = self.source_distribution_p[(None,) * x_t.dim()] + u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * ( + (1 - k_t) * p_0 + k_t * delta_1 + ) + + # Set u_t(x_t|x_t,x_1) = 0 + delta_t 
= F.one_hot(x_t, num_classes=self.vocabulary_size) # [B, L, V] + u = torch.where( + delta_t.to(dtype=torch.bool), torch.zeros_like(u), u + ) + # import pdb + # if i % 10 == 0: + # pdb.set_trace() + # Sample x_t ~ u_t( \cdot |x_t,x_1) + intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0 + mask_jump = torch.rand(size=x_t.shape, device=x_t.device) < 1 - torch.exp(-h * intensity) + + if mask_jump.sum() > 0: + x_t[mask_jump] = categorical( + u[mask_jump].to(dtype=dtype_categorical) + ) + + steps_counter += 1 + t = t + h + + if return_intermediates and (t in time_grid): + res.append(x_t.clone()) + + if verbose: + ctx.n = t.item() + ctx.refresh() + ctx.set_description(f"NFE: {steps_counter}") + + if return_intermediates: + if step_size is None: + return torch.stack(res, dim=0) + else: + return torch.stack(res, dim=0)[order] + else: + return x_t + + + @torch.no_grad() + def multi_guidance_sample( + self, + args, + x_init: Tensor, + step_size: Optional[float], + div_free: Union[float, Callable[[float], float]] = 0.0, + dtype_categorical: torch.dtype = torch.float32, + time_grid: Tensor = torch.tensor([0.0, 1.0]), + return_intermediates: bool = False, + verbose: bool = False, + score_models: list = None, + num_objectives: int = 1, + weights: list = None, + **model_extras, + ) -> Tensor: + + # score_list_0 = [] + # score_list_1 = [] + # score_list_2 = [] + # score_list_3 = [] + # score_list_4 = [] + # score_list_5 = [] + + import pdb + + if not div_free == 0.0: + raise NotImplementedError + + # Initialize the current state `x_t` with the initial state `X_0`. + time_grid = time_grid.to(device=x_init.device) + + if step_size is None: + # If step_size is None then set the t discretization to time_grid. + t_discretization = time_grid + n_steps = len(time_grid) - 1 + else: + # If step_size is float then t discretization is uniform with step size set by step_size. 
+ t_init = time_grid[0].item() + t_final = time_grid[-1].item() + assert ( + t_final - t_init + ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}." + + n_steps = ceil((t_final - t_init) / step_size) + t_discretization = torch.tensor( + [t_init + step_size * i for i in range(n_steps)] + [t_final], + device=x_init.device, + ) + + if return_intermediates: + # get order of intermediate steps: + order = torch.argsort(time_grid) + # Compute intermediate steps to return via nearest points in t_discretization to time_grid. + time_grid = get_nearest_times( + time_grid=time_grid, t_discretization=t_discretization + ) + + x_t = x_init.clone() + steps_counter = 0 + res = [] + + if return_intermediates: + res = [x_init.clone()] + + if verbose: + if not TQDM_AVAILABLE: + raise ImportError( + "tqdm is required for verbose mode. Please install it." + ) + ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}") + else: + ctx = nullcontext() + + # Randomly sample a weight vector + if weights is not None: + w = torch.tensor(weights).to(device=x_init.device) + else: + w, _ = select_random_weight_vector(num_objectives, args.num_div) + # w = torch.tensor([0.2, 0.7, 0.05, 0.05]).to(x_t.device) + w = w.to(device=x_init.device) + print(f"Weight Vector: {w}") + Phi = args.Phi_init + ema_r_t = None + + with ctx: + for i in range(n_steps): + t = t_discretization[i : i + 1] + h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1] + + p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras) + x_1 = categorical(p_1t.to(dtype=dtype_categorical)) + + # Checks if final step + if i != n_steps - 1: + # Compute u_t(y,x) + scheduler_output = self.path.scheduler(t=t) + k_t = scheduler_output.alpha_t + d_k_t = scheduler_output.d_alpha_t + u_t = d_k_t / (1 - k_t) * p_1t + + guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S = guided_transition_scoring(x_t, u_t, 
w, score_models, t, w, args) + + best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t) + + # best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t) + + # best_candidate = get_best_candidate(improvement_values, cand_tokens, delta_S) + + x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h) + + + steps_counter += 1 + t = t + h + + scores = [] + for i, s in enumerate(score_models): + sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s) + if 't' in sig.parameters: + candidate_scores = s(x_t, 1) + else: + candidate_scores = s(x_t) + + if isinstance(candidate_scores, tuple): + for score in candidate_scores: + scores.append(score.item()) + else: + scores.append(candidate_scores.item()) + print(scores) + + # print(f"Score {i}: {[round(s.item(), 4) for s in candidate_scores]}") + # if i == 0: + # score_list_0.append(round(candidate_scores[0].item(), 2)) + # # score_list_0.append(round(1-candidate_scores.item(), 2)) + # # score_list_1.append(round(candidate_scores[1].item(), 2)) + # if i == 1: + # score_list_1.append(round(candidate_scores.item(), 2)) + # # score_list_2.append(round(candidate_scores.item(), 2)) + # if i == 2: + # score_list_2.append(round(candidate_scores.item(), 2)) + # if i == 3: + # score_list_3.append(round(candidate_scores.item(), 2)) + # if i == 4: + # score_list_4.append(round(candidate_scores.item(), 2)) + # if i == 5: + # score_list_5.append(round(candidate_scores.item(), 2)) + + + if return_intermediates and (t in time_grid): + res.append(x_t.clone()) + + if verbose: + ctx.n = t.item() + ctx.refresh() + ctx.set_description(f"NFE: {steps_counter}") + + # print(score_list) + if return_intermediates: + if step_size is None: + return torch.stack(res, dim=0) + else: + return torch.stack(res, dim=0)[order] + 
else: + # return x_t, score_list_0, score_list_1, score_list_2, score_list_3, score_list_4, score_list_5 + return x_t \ No newline at end of file diff --git a/flow_matching/solver/ode_solver.py b/flow_matching/solver/ode_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..89975064673727d0376c5cb276225b4fb8dff9cb --- /dev/null +++ b/flow_matching/solver/ode_solver.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional, Sequence, Tuple, Union + +import torch +from torch import Tensor +from torchdiffeq import odeint + +from flow_matching.solver.solver import Solver +from flow_matching.utils import gradient, ModelWrapper + + +class ODESolver(Solver): + """A class to solve ordinary differential equations (ODEs) using a specified velocity model. + + This class utilizes a velocity field model to solve ODEs over a given time grid using numerical ode solvers. + + Args: + velocity_model (Union[ModelWrapper, Callable]): a velocity field model receiving :math:`(x,t)` and returning :math:`u_t(x)` + """ + + def __init__(self, velocity_model: Union[ModelWrapper, Callable]): + super().__init__() + self.velocity_model = velocity_model + + def sample( + self, + x_init: Tensor, + step_size: Optional[float], + method: str = "euler", + atol: float = 1e-5, + rtol: float = 1e-5, + time_grid: Tensor = torch.tensor([0.0, 1.0]), + return_intermediates: bool = False, + enable_grad: bool = False, + **model_extras, + ) -> Union[Tensor, Sequence[Tensor]]: + r"""Solve the ODE with the velocity field. + + Example: + + .. 
code-block:: python + + import torch + from flow_matching.utils import ModelWrapper + from flow_matching.solver import ODESolver + + class DummyModel(ModelWrapper): + def __init__(self): + super().__init__(None) + + def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: + return torch.ones_like(x) * 3.0 * t**2 + + velocity_model = DummyModel() + solver = ODESolver(velocity_model=velocity_model) + x_init = torch.tensor([0.0, 0.0]) + step_size = 0.001 + time_grid = torch.tensor([0.0, 1.0]) + + result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid) + + Args: + x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). Shape: [batch_size, ...]. + step_size (Optional[float]): The step size. Must be None for adaptive step solvers. + method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq. + atol (float): Absolute tolerance, used for adaptive step solvers. + rtol (float): Relative tolerance, used for adaptive step solvers. + time_grid (Tensor): The process is solved in the interval [min(time_grid, max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]). + return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False. + enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False. + **model_extras: Additional input for the model. + + Returns: + Union[Tensor, Sequence[Tensor]]: The last timestep when return_intermediates=False, otherwise all values specified in time_grid. 
+ """ + + time_grid = time_grid.to(x_init.device) + + def ode_func(t, x): + return self.velocity_model(x=x, t=t, **model_extras) + + ode_opts = {"step_size": step_size} if step_size is not None else {} + + with torch.set_grad_enabled(enable_grad): + # Approximate ODE solution with numerical ODE solver + sol = odeint( + ode_func, + x_init, + time_grid, + method=method, + options=ode_opts, + atol=atol, + rtol=rtol, + ) + + if return_intermediates: + return sol + else: + return sol[-1] + + def compute_likelihood( + self, + x_1: Tensor, + log_p0: Callable[[Tensor], Tensor], + step_size: Optional[float], + method: str = "euler", + atol: float = 1e-5, + rtol: float = 1e-5, + time_grid: Tensor = torch.tensor([1.0, 0.0]), + return_intermediates: bool = False, + exact_divergence: bool = False, + enable_grad: bool = False, + **model_extras, + ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: + r"""Solve for log likelihood given a target sample at :math:`t=0`. + + Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x. + The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`. + + Args: + x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`). + log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution. + step_size (Optional[float]): The step size. Must be None for adaptive step solvers. + method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq. + atol (float): Absolute tolerance, used for adaptive step solvers. + rtol (float): Relative tolerance, used for adaptive step solvers. + time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. 
Defaults to torch.tensor([1.0, 0.0]). + return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False. + exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator. + enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False. + **model_extras: Additional input for the model. + + Returns: + Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: Samples at time_grid and log likelihood values of given x_1. + """ + assert ( + time_grid[0] == 1.0 and time_grid[-1] == 0.0 + ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}" + + # Fix the random projection for the Hutchinson divergence estimator + if not exact_divergence: + z = (torch.randn_like(x_1).to(x_1.device) < 0) * 2.0 - 1.0 + + def ode_func(x, t): + return self.velocity_model(x=x, t=t, **model_extras) + + def dynamics_func(t, states): + xt = states[0] + with torch.set_grad_enabled(True): + xt.requires_grad_() + ut = ode_func(xt, t) + + if exact_divergence: + # Compute exact divergence + div = 0 + for i in range(ut.flatten(1).shape[1]): + div += gradient(ut[:, i], xt, create_graph=True)[:, i] + else: + # Compute Hutchinson divergence estimator E[z^T D_x(ut) z] + ut_dot_z = torch.einsum( + "ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1) + ) + grad_ut_dot_z = gradient(ut_dot_z, xt) + div = torch.einsum( + "ij,ij->i", + grad_ut_dot_z.flatten(start_dim=1), + z.flatten(start_dim=1), + ) + + return ut.detach(), div.detach() + + y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device)) + ode_opts = {"step_size": step_size} if step_size is not None else {} + + with torch.set_grad_enabled(enable_grad): + sol, log_det = odeint( + dynamics_func, + y_init, + time_grid, + method=method, + options=ode_opts, + atol=atol, + rtol=rtol, + ) + + x_source = sol[-1] + source_log_p = log_p0(x_source) + + if 
return_intermediates: + return sol, source_log_p + log_det[-1] + else: + return sol[-1], source_log_p + log_det[-1] diff --git a/flow_matching/solver/riemannian_ode_solver.py b/flow_matching/solver/riemannian_ode_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ff0fd84b0eb27e25e4b28fabb56c63b3ca5602 --- /dev/null +++ b/flow_matching/solver/riemannian_ode_solver.py @@ -0,0 +1,261 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Callable + +import torch +from torch import Tensor + +from flow_matching.solver.solver import Solver +from flow_matching.utils import ModelWrapper +from flow_matching.utils.manifolds import geodesic, Manifold + +try: + from tqdm import tqdm + + TQDM_AVAILABLE = True +except ImportError: + TQDM_AVAILABLE = False + + +class RiemannianODESolver(Solver): + r"""Riemannian ODE solver + Initialize the ``RiemannianODESolver``. + + Args: + manifold (Manifold): the manifold to solve on. + velocity_model (ModelWrapper): a velocity field model receiving :math:`(x,t)` + and returning :math:`u_t(x)` which is assumed to lie on the tangent plane at `x`. + """ + + def __init__(self, manifold: Manifold, velocity_model: ModelWrapper): + super().__init__() + self.manifold = manifold + self.velocity_model = velocity_model + + def sample( + self, + x_init: Tensor, + step_size: float, + projx: bool = True, + proju: bool = True, + method: str = "euler", + time_grid: Tensor = torch.tensor([0.0, 1.0]), + return_intermediates: bool = False, + verbose: bool = False, + enable_grad: bool = False, + **model_extras, + ) -> Tensor: + r"""Solve the ODE with the `velocity_field` on the manifold. + + Args: + x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). + step_size (float): The step size. 
+ + projx (bool): Whether to project the point onto the manifold at each step. Defaults to True. + proju (bool): Whether to project the vector field onto the tangent plane at each step. Defaults to True. + method (str): One of ["euler", "midpoint", "rk4"]. Defaults to "euler". + time_grid (Tensor, optional): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]). + return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False. + verbose (bool, optional): Whether to print progress bars. Defaults to False. + enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False. + **model_extras: Additional input for the model. + + Returns: + Tensor: The sampled sequence. Defaults to returning samples at :math:`t=1`. + + Raises: + ImportError: To run in verbose mode, tqdm must be installed. + """ + step_fns = { + "euler": _euler_step, + "midpoint": _midpoint_step, + "rk4": _rk4_step, + } + assert method in step_fns.keys(), f"Unknown method {method}" + step_fn = step_fns[method] + + def velocity_func(x, t): + return self.velocity_model(x=x, t=t, **model_extras) + + # --- Factor this out. + time_grid = torch.sort(time_grid.to(device=x_init.device)).values + + if step_size is None: + # If step_size is None then set the t discretization to time_grid. + t_discretization = time_grid + n_steps = len(time_grid) - 1 + else: + # If step_size is float then t discretization is uniform with step size set by step_size. + t_init = time_grid[0].item() + t_final = time_grid[-1].item() + assert ( + t_final - t_init + ) > step_size, f"Time interval [min(time_grid), max(time_grid)] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."
+ + n_steps = math.ceil((t_final - t_init) / step_size) + t_discretization = torch.tensor( + [step_size * i for i in range(n_steps)] + [t_final], + device=x_init.device, + ) + # --- + t0s = t_discretization[:-1] + + if verbose: + if not TQDM_AVAILABLE: + raise ImportError( + "tqdm is required for verbose mode. Please install it." + ) + t0s = tqdm(t0s) + + if return_intermediates: + xts = [] + i_ret = 0 + + with torch.set_grad_enabled(enable_grad): + xt = x_init + for t0, t1 in zip(t0s, t_discretization[1:]): + dt = t1 - t0 + xt_next = step_fn( + velocity_func, + xt, + t0, + dt, + manifold=self.manifold, + projx=projx, + proju=proju, + ) + if return_intermediates: + while ( + i_ret < len(time_grid) + and t0 <= time_grid[i_ret] + and time_grid[i_ret] <= t1 + ): + xts.append( + interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret]) + ) + i_ret += 1 + xt = xt_next + + if return_intermediates: + return torch.stack(xts, dim=0) + else: + return xt + + +def interp(manifold, xt, xt_next, t, t_next, t_ret): + return geodesic(manifold, xt, xt_next)( + (t_ret - t) / (t_next - t).reshape(1) + ).reshape_as(xt) + + +def _euler_step( + velocity_model: Callable, + xt: Tensor, + t0: Tensor, + dt: Tensor, + manifold: Manifold, + projx: bool = True, + proju: bool = True, +) -> Tensor: + r"""Perform an Euler step on a manifold. + + Args: + velocity_model (Callable): the velocity model + xt (Tensor): tensor containing the state at time t0 + t0 (Tensor): the time at which this step is taken + dt (Tensor): the step size + manifold (Manifold): a manifold object + projx (bool, optional): whether to project the state onto the manifold. Defaults to True. + proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True. 
+ + Returns: + Tensor: tensor containing the state after the step + """ + velocity_fn = lambda x, t: ( + manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t) + ) + projx_fn = lambda x: manifold.projx(x) if projx else x + + vt = velocity_fn(xt, t0) + + xt = xt + dt * vt + + return projx_fn(xt) + + +def _midpoint_step( + velocity_model: Callable, + xt: Tensor, + t0: Tensor, + dt: Tensor, + manifold: Manifold, + projx: bool = True, + proju: bool = True, +) -> Tensor: + r"""Perform a midpoint step on a manifold. + + Args: + velocity_model (Callable): the velocity model + xt (Tensor): tensor containing the state at time t0 + t0 (Tensor): the time at which this step is taken + dt (Tensor): the step size + manifold (Manifold): a manifold object + projx (bool, optional): whether to project the state onto the manifold. Defaults to True. + proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True. + + Returns: + Tensor: tensor containing the state after the step + """ + velocity_fn = lambda x, t: ( + manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t) + ) + projx_fn = lambda x: manifold.projx(x) if projx else x + + half_dt = 0.5 * dt + vt = velocity_fn(xt, t0) + x_mid = xt + half_dt * vt + x_mid = projx_fn(x_mid) + + xt = xt + dt * velocity_fn(x_mid, t0 + half_dt) + + return projx_fn(xt) + + +def _rk4_step( + velocity_model: Callable, + xt: Tensor, + t0: Tensor, + dt: Tensor, + manifold: Manifold, + projx: bool = True, + proju: bool = True, +) -> Tensor: + r"""Perform an RK4 step on a manifold. + + Args: + velocity_model (Callable): the velocity model + xt (Tensor): tensor containing the state at time t0 + t0 (Tensor): the time at which this step is taken + dt (Tensor): the step size + manifold (Manifold): a manifold object + projx (bool, optional): whether to project the state onto the manifold. Defaults to True. 
+ proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True. + + Returns: + Tensor: tensor containing the state after the step + """ + velocity_fn = lambda x, t: ( + manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t) + ) + projx_fn = lambda x: manifold.projx(x) if projx else x + + k1 = velocity_fn(xt, t0) + k2 = velocity_fn(projx_fn(xt + dt * k1 / 3), t0 + dt / 3) + k3 = velocity_fn(projx_fn(xt + dt * (k2 - k1 / 3)), t0 + dt * 2 / 3) + k4 = velocity_fn(projx_fn(xt + dt * (k1 - k2 + k3)), t0 + dt) + + return projx_fn(xt + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125) diff --git a/flow_matching/solver/solver.py b/flow_matching/solver/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..4819e1c0ff148c411648c6e248380f471efd0a6c --- /dev/null +++ b/flow_matching/solver/solver.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod + +from torch import nn, Tensor + + +class Solver(ABC, nn.Module): + """Abstract base class for solvers.""" + + @abstractmethod + def sample(self, x_0: Tensor = None) -> Tensor: + ... diff --git a/flow_matching/solver/utils.py b/flow_matching/solver/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a34ee4b0ca4c3ef64dfa1b55cb13da1187b965 --- /dev/null +++ b/flow_matching/solver/utils.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from torch import Tensor + + +def get_nearest_times(time_grid: Tensor, t_discretization: Tensor) -> Tensor: + distances = torch.cdist( + time_grid.unsqueeze(1), + t_discretization.unsqueeze(1), + compute_mode="donot_use_mm_for_euclid_dist", + ) + nearest_indices = distances.argmin(dim=1) + + return t_discretization[nearest_indices] diff --git a/flow_matching/utils/__init__.py b/flow_matching/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0085c44c5e044d5a7dd0c2bde9755e9d3c18c0b1 --- /dev/null +++ b/flow_matching/utils/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +from .categorical_sampler import categorical +from .model_wrapper import ModelWrapper +from .utils import expand_tensor_like, gradient, unsqueeze_to_match + +__all__ = [ + "unsqueeze_to_match", + "expand_tensor_like", + "gradient", + "categorical", + "ModelWrapper", +] diff --git a/flow_matching/utils/categorical_sampler.py b/flow_matching/utils/categorical_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..70937af5af1da4c3c2b18a85745197e240724a96 --- /dev/null +++ b/flow_matching/utils/categorical_sampler.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor + + +def categorical(probs: Tensor) -> Tensor: + r"""Categorical sampler according to weights in the last dimension of ``probs`` using :func:`torch.multinomial`. + + Args: + probs (Tensor): probabilities. + + Returns: + Tensor: Samples. 
+ """ + + return torch.multinomial(probs.flatten(0, -2), 1, replacement=True).view( + *probs.shape[:-1] + ) diff --git a/flow_matching/utils/manifolds/__init__.py b/flow_matching/utils/manifolds/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1148872f29c5e42e2578eab31751abaf6def51da --- /dev/null +++ b/flow_matching/utils/manifolds/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +from .manifold import Euclidean, Manifold +from .sphere import Sphere +from .torus import FlatTorus +from .utils import geodesic + +__all__ = [ + "Euclidean", + "Manifold", + "Sphere", + "FlatTorus", + "geodesic", +] diff --git a/flow_matching/utils/manifolds/manifold.py b/flow_matching/utils/manifolds/manifold.py new file mode 100644 index 0000000000000000000000000000000000000000..52a6a1bc6ea7d83e3aeefd335edb19a842aeca21 --- /dev/null +++ b/flow_matching/utils/manifolds/manifold.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +import abc + +import torch.nn as nn +from torch import Tensor + + +class Manifold(nn.Module, metaclass=abc.ABCMeta): + """A manifold class that contains projection operations and logarithm and exponential maps.""" + + @abc.abstractmethod + def expmap(self, x: Tensor, u: Tensor) -> Tensor: + r"""Computes exponential map :math:`\exp_x(u)`. 
+ + Args: + x (Tensor): point on the manifold + u (Tensor): tangent vector at point :math:`x` + + Raises: + NotImplementedError: if not implemented + + Returns: + Tensor: transported point + """ + raise NotImplementedError + + @abc.abstractmethod + def logmap(self, x: Tensor, y: Tensor) -> Tensor: + r"""Computes logarithmic map :math:`\log_x(y)`. + + Args: + x (Tensor): point on the manifold + y (Tensor): point on the manifold + + Raises: + NotImplementedError: if not implemented + + Returns: + Tensor: tangent vector at point :math:`x` + """ + raise NotImplementedError + + @abc.abstractmethod + def projx(self, x: Tensor) -> Tensor: + """Project point :math:`x` on the manifold. + + Args: + x (Tensor): point to be projected + + Raises: + NotImplementedError: if not implemented + + Returns: + Tensor: projected point on the manifold + """ + raise NotImplementedError + + @abc.abstractmethod + def proju(self, x: Tensor, u: Tensor) -> Tensor: + """Project vector :math:`u` on a tangent space for :math:`x`. + + Args: + x (Tensor): point on the manifold + u (Tensor): vector to be projected + + Raises: + NotImplementedError: if not implemented + + Returns: + Tensor: projected tangent vector + """ + raise NotImplementedError + + +class Euclidean(Manifold): + """The Euclidean manifold.""" + + def expmap(self, x: Tensor, u: Tensor) -> Tensor: + return x + u + + def logmap(self, x: Tensor, y: Tensor) -> Tensor: + return y - x + + def projx(self, x: Tensor) -> Tensor: + return x + + def proju(self, x: Tensor, u: Tensor) -> Tensor: + return u diff --git a/flow_matching/utils/manifolds/sphere.py b/flow_matching/utils/manifolds/sphere.py new file mode 100644 index 0000000000000000000000000000000000000000..76bf748f54b055b10cd4c4a085bf8204df6f47a0 --- /dev/null +++ b/flow_matching/utils/manifolds/sphere.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# + # This source code is licensed under the CC-by-NC license found in the + # LICENSE file in the root directory of this source tree. + + import torch + from torch import Tensor + + from flow_matching.utils.manifolds import Manifold + + + class Sphere(Manifold): + """Represents a hypersphere in :math:`R^D`.""" + + EPS = {torch.float32: 1e-4, torch.float64: 1e-7} + + def expmap(self, x: Tensor, u: Tensor) -> Tensor: + norm_u = u.norm(dim=-1, keepdim=True) + exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u + retr = self.projx(x + u) + cond = norm_u > self.EPS[norm_u.dtype] + + return torch.where(cond, exp, retr) + + def logmap(self, x: Tensor, y: Tensor) -> Tensor: + u = self.proju(x, y - x) + dist = self.dist(x, y, keepdim=True) + cond = dist.gt(self.EPS[x.dtype]) + result = torch.where( + cond, + u * dist / u.norm(dim=-1, keepdim=True).clamp_min(self.EPS[x.dtype]), + u, + ) + return result + + def projx(self, x: Tensor) -> Tensor: + return x / x.norm(dim=-1, keepdim=True) + + def proju(self, x: Tensor, u: Tensor) -> Tensor: + return u - (x * u).sum(dim=-1, keepdim=True) * x + + def dist(self, x: Tensor, y: Tensor, *, keepdim=False) -> Tensor: + inner = (x * y).sum(-1, keepdim=keepdim) + return torch.acos(inner) diff --git a/flow_matching/utils/manifolds/torus.py b/flow_matching/utils/manifolds/torus.py new file mode 100644 index 0000000000000000000000000000000000000000..3587ed7567e3f55bb28edb176ab042882048d858 --- /dev/null +++ b/flow_matching/utils/manifolds/torus.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +from torch import Tensor + +from flow_matching.utils.manifolds import Manifold + + +class FlatTorus(Manifold): + r"""Represents a flat torus on the :math:`[0, 2\pi]^D` subspace.
Isometric to the product of 1-D spheres.""" + + def expmap(self, x: Tensor, u: Tensor) -> Tensor: + return (x + u) % (2 * math.pi) + + def logmap(self, x: Tensor, y: Tensor) -> Tensor: + return torch.atan2(torch.sin(y - x), torch.cos(y - x)) + + def projx(self, x: Tensor) -> Tensor: + return x % (2 * math.pi) + + def proju(self, x: Tensor, u: Tensor) -> Tensor: + return u diff --git a/flow_matching/utils/manifolds/utils.py b/flow_matching/utils/manifolds/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b83d2faf8ac4f182ca8621f25cafaefca32f5091 --- /dev/null +++ b/flow_matching/utils/manifolds/utils.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch +from torch import Tensor + +from flow_matching.utils.manifolds import Manifold + + +def geodesic( + manifold: Manifold, start_point: Tensor, end_point: Tensor +) -> Callable[[Tensor], Tensor]: + """Generate parameterized function for geodesic curve. + + Args: + manifold (Manifold): the manifold to compute geodesic on. + start_point (Tensor): point on the manifold at :math:`t=0`. + end_point (Tensor): point on the manifold at :math:`t=1`. + + Returns: + Callable[[Tensor], Tensor]: a function that takes in :math:`t` and outputs the geodesic at time :math:`t`. + """ + + shooting_tangent_vec = manifold.logmap(start_point, end_point) + + def path(t: Tensor) -> Tensor: + """Generate parameterized function for geodesic curve. + + Args: + t (Tensor): Times at which to compute points of the geodesics. + + Returns: + Tensor: geodesic path evaluated at time t. 
+ """ + tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec) + points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs) + + return points_at_time_t + + return path diff --git a/flow_matching/utils/model_wrapper.py b/flow_matching/utils/model_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7d932c851b3d2194bd63ce597ef924dc819a7f --- /dev/null +++ b/flow_matching/utils/model_wrapper.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC + +from torch import nn, Tensor + + +class ModelWrapper(ABC, nn.Module): + """ + This class is used to wrap around another model, adding custom forward pass logic. + """ + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor: + r""" + This method defines how inputs should be passed through the wrapped model. + Here, we're assuming that the wrapped model takes both :math:`x` and :math:`t` as input, + along with any additional keyword arguments. + + Optional things to do here: + - check that t is in the dimensions that the model is expecting. + - add a custom forward pass logic. + - call the wrapped model. + + | given x, t + | returns the model output for input x at time t, with extra information `extra`. + + Args: + x (Tensor): input data to the model (batch_size, ...). + t (Tensor): time (batch_size). + **extras: additional information forwarded to the model, e.g., text condition. + + Returns: + Tensor: model output. 
+ """ + return self.model(x=x, t=t, **extras) diff --git a/flow_matching/utils/multi_guidance.py b/flow_matching/utils/multi_guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e33faf5d7510c07acaffb8711108fcab414474 --- /dev/null +++ b/flow_matching/utils/multi_guidance.py @@ -0,0 +1,216 @@ +import torch +from flow_matching.utils import categorical +import math +import inspect + +def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor: + def rec(n, H): + if n == 1: + return [[H]] + points = [] + for i in range(H + 1): + for tail in rec(n - 1, H - i): + points.append([i] + tail) + return points + + points = rec(num_obj, num_div) + weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div + return weight_vectors + +def select_random_weight_vector(num_obj: int, num_div: int): + weight_vectors = generate_simplex_lattice_points(num_obj, num_div) + idx = torch.randint(0, weight_vectors.size(0), (1,)).item() + random_weight_vector = weight_vectors[idx] + return random_weight_vector, weight_vectors + +def z_score_norm(tensor, eps=1e-8): + mean = tensor.mean(dim=-1, keepdim=True) + std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps) + return (tensor - mean) / std + +def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args): + B, L, vocab_size = u_t.shape + device = x_t.device + guided_u_t = u_t.clone() + + # 1. Randomly select one position per sequence. + pos_indices = torch.randint(low=1, high=L-2, size=(B,), device=device) # shape: (B,) # CHANGE! + batch_idx = torch.arange(B, device=device) + current_tokens = x_t[batch_idx, pos_indices] # shape: (B,) + + # 2. Build candidate tokens for each sequence and remove self-transition. 
+ full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size) # (B, vocab_size) + mask = (full_cand_tokens != current_tokens.unsqueeze(1)) # (B, vocab_size) + # Now, cand_tokens contains only candidate tokens that differ from the current token. + cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, vocab_size - 1) # (B, vocab_size-1) + + # 3. Create candidate sequences by replacing the token at the selected position. + new_x = x_t.unsqueeze(1).expand(B, vocab_size, L).clone() + new_x = new_x[mask].view(B, vocab_size - 1, L) # (B, vocab_size-1, L) + new_x[batch_idx, :, pos_indices] = cand_tokens + + new_x_flat = new_x.view(B * (vocab_size - 1), L) + improvements_list = [] + with torch.no_grad(): + count = 0 + for i, s in enumerate(s_models): + sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s) + if 't' in sig.parameters: + candidate_scores = s(new_x_flat, t) + base_score = s(x_t, t) + else: + candidate_scores = s(new_x_flat) + base_score = s(x_t) + + if isinstance(candidate_scores, tuple): + for k, score in enumerate(candidate_scores): + improvement = candidate_scores[k].view(B, vocab_size - 1) - base_score[k].unsqueeze(1) + improvement = improvement.float() + improvement *= importance[count] + improvements_list.append(improvement.unsqueeze(2)) + count += 1 + else: + improvement = candidate_scores.view(B, vocab_size - 1) - base_score.unsqueeze(1) + improvement = improvement.float() + improvement *= importance[count] + improvements_list.append(improvement.unsqueeze(2)) # (B, vocab_size-1, 1) + count += 1 + + improvement_values = torch.cat(improvements_list, dim=2) # (B, vocab_size-1, N) + if args.is_peptide: + improvement_values[:, :4, :] = -10 # Mask non-residue positions + + # 5. 
Compute ranking scores I_n + ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1 # (B, vocab_size-1, N) + I_n = ranks / float(vocab_size - 1) + avg_I = I_n.mean(dim=2) + norm_avg_I = z_score_norm(avg_I) # (B, vocab_size-1) + + # 6. Compute directional score D + D = (improvement_values * w.view(1, 1, -1)).sum(dim=2) + norm_D = z_score_norm(D) # (B, vocab_size-1) + + # 7. Combine the scores + delta_S = norm_avg_I + args.lambda_ * norm_D # (B, vocab_size-1) + + # 9. Update the guided velocities at the selected positions. + factor = torch.exp(args.beta * delta_S) # (B, vocab_size-1) + factor = torch.clamp(factor, min=-100, max=100) + + guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor + + # 10. For the self-transition (current token) at the selected position, + # set its guided velocity to be the negative sum of the updated off-diagonals. + updated_vals = guided_u_t[batch_idx, pos_indices, :] # (B, vocab_size) + sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens] + guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag + + return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S + +def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None): + B, num_candidates, N = improvement_values.shape + device = improvement_values.device + eps = 1e-8 + + # Compute norms and angles. 
+ imp_norm = torch.norm(improvement_values.float(), dim=2) # (B, num_candidates) + dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2) + w_norm = torch.norm(w) + eps + cos_angle = dot_product / (imp_norm * w_norm + eps) + cos_angle = cos_angle.clamp(-1.0, 1.0) + angles = torch.acos(cos_angle) # (B, num_candidates) + + valid_mask = angles < math.pi / 2 + accepted_mask = valid_mask & (angles <= Phi) # (B, num_candidates) + + # Determine the best candidate for each sequence. + # We'll use a loop over batch items (batch size is typically moderate). + best_candidate = torch.empty(B, dtype=torch.long, device=device) + for i in range(B): + # For sequence i, consider only valid candidates. + if valid_mask[i].any(): + # There is at least one candidate with α^i < π. + if accepted_mask[i].any(): + # At least one candidate passes the hypercone: choose the one with max delta_S among accepted. + candidate_idx = torch.argmax(delta_S[i].masked_fill(~accepted_mask[i], float('-inf'))) + else: + # No candidate was accepted, but some are valid. Select best candidate among valid ones. + candidate_idx = torch.argmax(delta_S[i].masked_fill(~valid_mask[i], float('-inf'))) + best_candidate[i] = cand_tokens[i, candidate_idx] + else: + # No candidate is valid (all α^i >= π) → self-transition. + best_candidate[i] = -1 + + # Compute rejection rate only over valid candidates. + rejection_rates = [] + for i in range(B): + valid_candidates = valid_mask[i] + total_valid = valid_candidates.sum().item() + if total_valid > 0: + # Among valid candidates, count how many are rejected. + num_rejected = (valid_candidates.sum() - accepted_mask[i].sum()).item() + rejection_rates.append(num_rejected / total_valid) + if len(rejection_rates) > 0: + r_t = sum(rejection_rates) / len(rejection_rates) + else: + # If no sequence has any valid candidate, set r_t to 0. 
+ r_t = 0.0 + + if ema_r_t is None: + ema_r_t = args.tau + + # Update hypercone angle and ema rejection rate only if there is at least one valid candidate in the batch. + if valid_mask.any(): + new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t + new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device)) + new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item() + else: + new_ema_r_t = ema_r_t + new_Phi = Phi # No update if no valid candidate exists. + + return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t + +def get_best_candidate(improvement_values, cand_tokens, delta_S): + B, num_candidates, N = improvement_values.shape + device = improvement_values.device + best_candidate = torch.empty(B, dtype=torch.long, device=device) + + for i in range(B): + candidate_idx = torch.argmax(delta_S[i]) + best_candidate[i] = cand_tokens[i, candidate_idx] + + return best_candidate + +def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h): + B, L, V = guided_u_t.shape + device = x_t.device + u = torch.zeros_like(guided_u_t) + + valid_mask = best_candidate != -1 + if valid_mask.any(): + valid_idx = torch.nonzero(valid_mask).squeeze(-1) + # For these sequences, update the velocity at the selected position and candidate token. + u[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]] = \ + guided_u_t[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]] + + # Compute intensity at the selected positions. + # For sequences with no valid candidate (i.e. self-transition), intensity remains zero. + intensity = torch.zeros(B, device=device) + if valid_mask.any(): + intensity[valid_idx] = u[valid_idx, pos_indices[valid_idx]].sum(dim=-1) + + # According to the Euler Sampling formula, `p_jump` should be `1 - torch.exp(-h * intensity)` + # However, since `h = 1 / T` is small, p_jump becomes tiny and slows down sampling. + # To compensate, we scale `intensity` by T. 
We can do this because this is equivalent to setting `args.beta` to `T * args.beta`. + # So for faster sampling, we just use `1 - torch.exp(-1 * intensity)` + p_jump = 1 - torch.exp(-1 * intensity) + + rand_val = torch.rand(B, device=device) + + jump_decision = (rand_val < p_jump) & valid_mask + if True in jump_decision.tolist(): + print("Jump!") + # For sequences where a jump is decided, update the token at pos_indices to best_candidate. + x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision] + + return x_t diff --git a/flow_matching/utils/multi_guidance_cnp.py b/flow_matching/utils/multi_guidance_cnp.py new file mode 100644 index 0000000000000000000000000000000000000000..7dbdb213634b6ca140cf429c7a7f8484d9442ca8 --- /dev/null +++ b/flow_matching/utils/multi_guidance_cnp.py @@ -0,0 +1,217 @@ +import torch +from flow_matching.utils import categorical +import math +import inspect +import random + +def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor: + def rec(n, H): + if n == 1: + return [[H]] + points = [] + for i in range(H + 1): + for tail in rec(n - 1, H - i): + points.append([i] + tail) + return points + + points = rec(num_obj, num_div) + weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div + return weight_vectors + +def select_random_weight_vector(num_obj: int, num_div: int): + weight_vectors = generate_simplex_lattice_points(num_obj, num_div) + idx = torch.randint(0, weight_vectors.size(0), (1,)).item() + random_weight_vector = weight_vectors[idx] + return random_weight_vector, weight_vectors + +def z_score_norm(tensor, eps=1e-8): + mean = tensor.mean(dim=-1, keepdim=True) + std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps) + return (tensor - mean) / std + +def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args): + B, L, vocab_size = u_t.shape + device = x_t.device + guided_u_t = u_t.clone() + + # 1. Randomly select one position per sequence. 
def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
    """Tilt the flow velocities u_t toward single-token edits that improve the guidance scores.

    Args:
        x_t: (B, L) current token sequences.
        u_t: (B, L, vocab_size) unguided velocities.
        w: (N,) objective weight vector.
        s_models: sequence of scorer callables; each returns a (M,) score
            (or a tuple of such heads) for a batch of sequences.
        t: time step, forwarded to scorers whose signature has a 't' parameter.
        importance: per-head importance multipliers, indexed in scorer order.
        args: namespace with is_peptide, lambda_, beta.

    Returns:
        (guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S).
    """
    B, L, vocab_size = u_t.shape
    device = x_t.device
    guided_u_t = u_t.clone()

    # Pick one editable position. NOTE(review): this hard-codes a single
    # random position (size 1, ignoring B) and skips index 6 — looks like a
    # debugging override of the commented-out batched randint; confirm intent.
    # pos_indices = torch.randint(low=1, high=L-2, size=(B,), device=device)
    pos_indices = torch.tensor([random.choice([i for i in range(1, L - 2) if i != 6])]).to(x_t.device)
    batch_idx = torch.arange(B, device=device)
    current_tokens = x_t[batch_idx, pos_indices]

    # Candidate tokens: every vocab entry except the current token and token
    # id 23, leaving vocab_size - 2 candidates per sequence.
    full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)
    keep = (full_cand_tokens != current_tokens.unsqueeze(1)) & (full_cand_tokens != 23)
    cand_tokens = torch.masked_select(full_cand_tokens, keep).view(B, vocab_size - 2)

    # Build one mutated copy of x_t per candidate token.
    cand_seqs = x_t.unsqueeze(1).expand(B, vocab_size, L).clone()
    cand_seqs = cand_seqs[keep].view(B, vocab_size - 2, L)
    cand_seqs[batch_idx, :, pos_indices] = cand_tokens

    flat_seqs = cand_seqs.view(B * (vocab_size - 2), L)
    improvements_list = []
    with torch.no_grad():
        count = 0
        for scorer in s_models:
            # Time-conditioned scorers take t; plain scorers take tokens only.
            sig = inspect.signature(scorer.forward) if hasattr(scorer, 'forward') else inspect.signature(scorer)
            if 't' in sig.parameters:
                cand_scores = scorer(flat_seqs, t)
                base_scores = scorer(x_t, t)
            else:
                cand_scores = scorer(flat_seqs)
                base_scores = scorer(x_t)

            heads = cand_scores if isinstance(cand_scores, tuple) else (cand_scores,)
            bases = base_scores if isinstance(base_scores, tuple) else (base_scores,)
            for head, base in zip(heads, bases):
                delta = head.view(B, vocab_size - 2) - base.unsqueeze(1)
                delta = delta.float() * importance[count]
                improvements_list.append(delta.unsqueeze(2))  # (B, vocab_size-2, 1)
                count += 1

    improvement_values = torch.cat(improvements_list, dim=2)  # (B, vocab_size-2, N)
    if args.is_peptide:
        # Mask out the first four (non-residue) candidate slots.
        improvement_values[:, :4, :] = -10

    # Rank-based consensus score, z-normalized across candidates.
    ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1
    I_n = ranks / float(vocab_size - 2)
    norm_avg_I = z_score_norm(I_n.mean(dim=2))  # (B, vocab_size-2)

    # Weighted directional score, z-normalized across candidates.
    D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    norm_D = z_score_norm(D)  # (B, vocab_size-2)

    delta_S = norm_avg_I + args.lambda_ * norm_D  # (B, vocab_size-2)

    # Exponentially tilt the off-diagonal velocities by the combined score.
    factor = torch.exp(args.beta * delta_S)
    factor = torch.clamp(factor, min=-100, max=100)
    guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = \
        u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor

    # Keep each row a valid rate vector: diagonal = -sum of off-diagonals.
    row_vals = guided_u_t[batch_idx, pos_indices, :]  # (B, vocab_size)
    off_diag_sum = row_vals.sum(dim=1) - row_vals[batch_idx, current_tokens]
    guided_u_t[batch_idx, pos_indices, current_tokens] = -off_diag_sum

    return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
    """Filter candidates through an adaptive hypercone around the weight vector w.

    A candidate is "valid" if its improvement vector points into w's
    half-space (angle < pi/2) and "accepted" if additionally within the cone
    angle Phi. The cone angle adapts multiplicatively from an EMA of the
    per-batch rejection rate.

    Returns:
        (best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t);
        best_candidate is -1 for rows with no valid candidate (self-transition).
    """
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    eps = 1e-8

    # Angle between each candidate improvement vector and w.
    imp_norm = torch.norm(improvement_values.float(), dim=2)  # (B, num_candidates)
    dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    w_norm = torch.norm(w) + eps
    cos_angle = (dot_product / (imp_norm * w_norm + eps)).clamp(-1.0, 1.0)
    angles = torch.acos(cos_angle)  # (B, num_candidates)

    valid_mask = angles < math.pi / 2
    accepted_mask = valid_mask & (angles <= Phi)

    best_candidate = torch.empty(B, dtype=torch.long, device=device)
    rejection_rates = []
    for b in range(B):
        valid_b = valid_mask[b]
        if valid_b.any():
            # Prefer accepted candidates; otherwise fall back to best valid.
            pool = accepted_mask[b] if accepted_mask[b].any() else valid_b
            idx = torch.argmax(delta_S[b].masked_fill(~pool, float('-inf')))
            best_candidate[b] = cand_tokens[b, idx]
        else:
            best_candidate[b] = -1  # self-transition
        # Rejection rate is computed only over valid candidates.
        n_valid = valid_b.sum().item()
        if n_valid > 0:
            n_rejected = (valid_b.sum() - accepted_mask[b].sum()).item()
            rejection_rates.append(n_rejected / n_valid)

    r_t = sum(rejection_rates) / len(rejection_rates) if rejection_rates else 0.0

    if ema_r_t is None:
        ema_r_t = args.tau

    if valid_mask.any():
        # EMA of rejection rate drives a multiplicative cone-angle update.
        new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
        new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
        new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item()
    else:
        new_ema_r_t = ema_r_t
        new_Phi = Phi  # no update if no valid candidate exists

    return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t
def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
    """One Euler jump step of the guided CTMC sampler (in-place on x_t).

    Args:
        x_t: (B, L) token sequences, modified in place and returned.
        pos_indices: (B,) edited position per sequence.
        best_candidate: (B,) target token per sequence, -1 for self-transition.
        guided_u_t: (B, L, V) guided velocities.
        h: nominal step size — deliberately unused, see note below.
    """
    B, L, V = guided_u_t.shape
    device = x_t.device
    rate = torch.zeros_like(guided_u_t)
    intensity = torch.zeros(B, device=device)

    has_candidate = best_candidate != -1
    if has_candidate.any():
        rows = torch.nonzero(has_candidate).squeeze(-1)
        # Keep only the velocity of the chosen (position, token) pair.
        rate[rows, pos_indices[rows], best_candidate[rows]] = \
            guided_u_t[rows, pos_indices[rows], best_candidate[rows]]
        # Jump intensity at the edited position; stays 0 for self-transition rows.
        intensity[rows] = rate[rows, pos_indices[rows]].sum(dim=-1)

    # Exact Euler formula would be 1 - exp(-h * intensity); since h = 1/T is
    # tiny, intensity is used unscaled on purpose (equivalent to scaling
    # args.beta by T) to keep sampling fast.
    p_jump = 1 - torch.exp(-1 * intensity)

    rand_val = torch.rand(B, device=device)
    jump_decision = (rand_val < p_jump) & has_candidate
    # Apply the accepted single-token edits.
    x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision]

    return x_t
def unsqueeze_to_match(source: Tensor, target: Tensor, how: str = "suffix") -> Tensor:
    """
    Unsqueeze `source` until it has as many dimensions as `target`.

    Args:
        source (Tensor): tensor to be unsqueezed.
        target (Tensor): tensor whose dimensionality is matched.
        how (str, optional): "prefix" adds leading singleton dims,
            "suffix" (default) adds trailing ones.

    Returns:
        Tensor: `source` viewed with the extra singleton dimensions.
    """
    assert (
        how == "prefix" or how == "suffix"
    ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."

    axis = 0 if how == "prefix" else -1
    # If source already has >= target.dim() dims, the range is empty (no-op).
    for _ in range(target.dim() - source.dim()):
        source = source.unsqueeze(axis)
    return source
def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
    """Broadcast a 1-D batch vector to the full shape of `expand_to`.

    Args:
        input_tensor (Tensor): (batch_size,).
        expand_to (Tensor): (batch_size, ...).

    Returns:
        Tensor: (batch_size, ...) with the batch value repeated along all
        remaining dimensions.
    """
    assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
    assert (
        input_tensor.shape[0] == expand_to.shape[0]
    ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."

    trailing = expand_to.ndim - input_tensor.ndim
    # Clone first so the expanded view does not alias the caller's tensor.
    reshaped = input_tensor.clone().reshape(-1, *([1] * trailing))
    return reshaped.expand_as(expand_to)
+ """ + + if grad_outputs is None: + grad_outputs = torch.ones_like(output).detach() + grad = torch.autograd.grad( + output, x, grad_outputs=grad_outputs, create_graph=create_graph + )[0] + return grad diff --git a/models/classifier.py b/models/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..ba325e5dd77e85c1895b5fdc419233b3a17a1bda --- /dev/null +++ b/models/classifier.py @@ -0,0 +1,116 @@ +from torch import nn +import torch.nn.functional as F +import torch +import numpy as np +import copy +import pdb + +class GaussianFourierProjection(nn.Module): + """ + Gaussian random features for encoding time steps. + """ + + def __init__(self, embed_dim, scale=30.): + super().__init__() + # Randomly sample weights during initialization. These weights are fixed + # during optimization and are not trainable. + self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + +class Dense(nn.Module): + """ + A fully connected layer that reshapes outputs to feature maps. + """ + + def __init__(self, input_dim, output_dim): + super().__init__() + self.dense = nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.dense(x)[...] 
class CNNClassifier(nn.Module):
    """Stacked dilated CNN over sequences, usable as a classifier or a denoiser.

    With classifier=True the per-position features are mean-pooled and mapped
    to num_cls logits; otherwise per-position logits over the alphabet are
    returned. Time conditioning is only applied when args.clean_data is False.
    """

    def __init__(self, args, alphabet_size, num_cls, classifier=False):
        super().__init__()
        self.alphabet_size = alphabet_size
        self.args = args
        self.classifier = classifier
        self.num_cls = num_cls

        if self.args.clean_data:
            # Discrete tokens in: embed directly, no simplex input channel math.
            self.linear = nn.Embedding(self.alphabet_size, embedding_dim=args.hidden_dim)
        else:
            # NOTE(review): precedence here is `a or (not b and c)` — kept as written.
            expanded_simplex_input = args.cls_expanded_simplex or not classifier and (args.mode == 'dirichlet' or args.mode == 'riemannian')
            inp_size = self.alphabet_size * (2 if expanded_simplex_input else 1)
            if (args.mode == 'ardm' or args.mode == 'lrar') and not classifier:
                inp_size += 1  # plus one for the mask token of these models
            self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
        self.time_embedder = nn.Sequential(
            GaussianFourierProjection(embed_dim=args.hidden_dim),
            nn.Linear(args.hidden_dim, args.hidden_dim),
        )

        self.num_layers = 5 * args.num_cnn_stacks
        # Five base conv layers with growing dilation, replicated per stack.
        base_convs = [
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256),
        ]
        self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in base_convs for _ in range(args.num_cnn_stacks)])
        self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
        self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
        self.final_conv = nn.Sequential(
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(args.hidden_dim, args.hidden_dim if classifier else self.alphabet_size, kernel_size=1),
        )
        self.dropout = nn.Dropout(args.dropout)
        if classifier:
            self.cls_head = nn.Sequential(
                nn.Linear(args.hidden_dim, args.hidden_dim),
                nn.ReLU(),
                nn.Linear(args.hidden_dim, self.num_cls),
            )

        if self.args.cls_free_guidance and not self.classifier:
            # Class embedding for classifier-free guidance; +1 slot for "no class".
            self.cls_embedder = nn.Embedding(num_embeddings=self.num_cls + 1, embedding_dim=args.hidden_dim)
            self.cls_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])

    def forward(self, seq, t, cls=None, return_embedding=False):
        if self.args.clean_data:
            feat = self.linear(seq).permute(0, 2, 1)
        else:
            time_emb = F.relu(self.time_embedder(t))
            feat = F.relu(self.linear(seq.permute(0, 2, 1)))

        use_cls = self.args.cls_free_guidance and not self.classifier and cls is not None
        if use_cls:
            cls_emb = self.cls_embedder(cls)

        for i in range(self.num_layers):
            h = self.dropout(feat.clone())
            if not self.args.clean_data:
                h = h + self.time_layers[i](time_emb)[:, :, None]
            if use_cls:
                h = h + self.cls_layers[i](cls_emb)[:, :, None]
            h = self.norms[i](h.permute(0, 2, 1))
            h = F.relu(self.convs[i](h.permute(0, 2, 1)))
            # Residual connection only when the channel count is unchanged.
            feat = h + feat if h.shape == feat.shape else h

        feat = self.final_conv(feat).permute(0, 2, 1)
        if self.classifier:
            feat = feat.mean(dim=1)  # pool over sequence positions
            if return_embedding:
                embedding = self.cls_head[:1](feat)
                return self.cls_head[1:](embedding), embedding
            return self.cls_head(feat)
        return feat
class GaussianFourierProjection(nn.Module):
    """Random Fourier features of a scalar time step (weights fixed at init)."""

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Frozen random frequencies; requires_grad=False keeps them untrained.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)


class Dense(nn.Module):
    """Thin wrapper around nn.Linear, kept for API parity with feature maps."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.dense(x)


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class CNNModel(nn.Module):
    """Time-conditioned dilated CNN producing per-position logits.

    forward(x, t): x (B, L) long token indices and t (B,) time steps map to
    (B, L, alphabet_size) logits, mean-centered over the last axis.
    """

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            alphabet_size (int): number of token classes.
            embed_dim (int): dimensionality of token and time embeddings.
            hidden_dim (int): conv channel width.
        """
        super().__init__()
        self.alphabet_size = alphabet_size
        self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim),
        )
        self.swish = Swish()

        n = hidden_dim
        self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)
        # Five conv blocks with growing dilation (wide receptive field).
        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
        ])
        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])
        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1),
        )

    def forward(self, x, t):
        """
        Args:
            x: (B, L) token indices.
            t: (B,) time steps.

        Returns:
            (B, L, alphabet_size) logits, mean-centered across the alphabet.
        """
        emb = self.token_embedding(x)            # (B, L, embed_dim)
        t_emb = self.swish(self.time_embed(t))   # (B, embed_dim)

        out = self.swish(self.linear(emb.permute(0, 2, 1)))  # (B, n, L)
        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # Time conditioning is injected additively before each norm+conv.
            h = self.swish(block(norm(out + dense(t_emb)[:, :, None])))
            out = h + out if h.shape == out.shape else h

        out = self.final(out).permute(0, 2, 1)   # (B, L, alphabet_size)
        # Center logits so each position's scores sum to ~0.
        return out - out.mean(dim=-1, keepdim=True)
class DirichletCNNModel(nn.Module):
    """Time-conditioned dilated CNN over simplex-valued inputs.

    forward(seq, t): seq (B, L, inp_size) simplex features and t (B,) time
    steps map to (B, L, alphabet_size) per-position logits.
    """

    def __init__(self, args, alphabet_size):
        super().__init__()
        self.alphabet_size = alphabet_size
        self.args = args

        # Input may use the doubled ("expanded") simplex representation.
        expanded = args.cls_expanded_simplex and (args.mode == 'dirichlet' or args.mode == 'riemannian')
        inp_size = self.alphabet_size * (2 if expanded else 1)
        self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
        self.time_embedder = nn.Sequential(
            GaussianFourierProjection(embed_dim=args.hidden_dim),
            nn.Linear(args.hidden_dim, args.hidden_dim),
        )

        self.num_layers = 5 * args.num_cnn_stacks
        # Five base conv layers with growing dilation, copied per stack.
        base_convs = [
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256),
        ]
        self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in base_convs for _ in range(args.num_cnn_stacks)])
        self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
        self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
        self.final_conv = nn.Sequential(
            nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(args.hidden_dim, self.alphabet_size, kernel_size=1),
        )
        self.dropout = nn.Dropout(args.dropout)

    def forward(self, seq, t):
        time_emb = F.relu(self.time_embedder(t))
        feat = F.relu(self.linear(seq.permute(0, 2, 1)))

        for i in range(self.num_layers):
            h = self.dropout(feat.clone())
            if not self.args.clean_data:
                h = h + self.time_layers[i](time_emb)[:, :, None]
            h = self.norms[i](h.permute(0, 2, 1))
            h = F.relu(self.convs[i](h.permute(0, 2, 1)))
            # Residual connection only when shapes line up.
            feat = h + feat if h.shape == feat.shape else h

        return self.final_conv(feat).permute(0, 2, 1)
def parse_motifs(motif: str) -> torch.Tensor:
    """Parse a 1-based motif specification into 0-based positions.

    Args:
        motif: Comma-separated list of 1-based positions and inclusive
            ranges, e.g. "1-3,5" -> positions [0, 1, 2, 4].

    Returns:
        1-D long tensor of 0-based positions, in the order given
        (duplicates from overlapping ranges are kept).
    """
    # Fixed: the return annotation previously said `list`, but the function
    # has always returned a torch.Tensor.
    positions = []
    for part in motif.split(','):
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            positions.extend(range(start, end + 1))
        else:
            positions.append(int(part))

    # Convert human-friendly 1-based indices to 0-based tensor indices.
    positions = [pos - 1 for pos in positions]
    print(f'Target Motifs: {positions}')
    return torch.tensor(positions)
learning_rate + self.max_epochs = max_epochs + self.kl_weight = kl_weight + + self.classification_threshold = nn.Parameter(torch.tensor(0.5)) # Initial threshold + self.historical_memory = 0.9 + self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925]) # binding_site weights, non-bidning site weights + + def forward(self, binder_tokens, target_tokens): + peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state + protein_sequence = self.esm_model(**target_tokens).last_hidden_state + + prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \ + seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence, + protein_sequence) + + prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc) + + prot_enc = self.final_ffn(prot_enc) + + prot_enc = self.output_projection_prot(prot_enc) + + return prot_enc + + def get_probs(self, x_t, target_sequence): + ''' + Inputs: + - xt: Shape (bsz, seq_len) + - target_sequence: Shape (1, tgt_len) + ''' + # pdb.set_trace() + target_sequence = target_sequence.repeat(x_t.shape[0], 1) + binder_attention_mask = torch.ones_like(x_t) + target_attention_mask = torch.ones_like(target_sequence) + + binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0 + target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0 + + binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask.to(x_t.device)} + target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)} + + logits = self.forward(binder_tokens, target_tokens).squeeze(-1) + # pdb.set_trace() + logits[:, 0] = logits[:, -1] = -100 # float('-inf') + probs = torch.sigmoid(logits) + + return probs # shape (bsz, tgt_len) + + def motif_score(self, x_t, target_sequence, motifs): + probs = self.get_probs(x_t, target_sequence) + motif_probs = probs[:, motifs] + motif_score = motif_probs.sum(dim=-1) / len(motifs) + # 
class MotifModel(nn.Module):
    """Adapter that turns a BindEvaluator-style scorer into a single-argument model.

    Stores the target sequence, motif positions, and penalty flag at
    construction so forward(x) only needs the candidate binders.
    """

    def __init__(self, bindevaluator, target_sequence, motifs, penalty=False):
        super().__init__()
        self.bindevaluator = bindevaluator
        self.target_sequence = target_sequence
        self.motifs = motifs
        self.penalty = penalty

    def forward(self, x):
        # Delegate to the evaluator's motif scoring with the stored context.
        return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
Freeze ESM parameters if needed + if freeze_esm: + for param in self.esm_model.parameters(): + param.requires_grad = False + + # Get ESM hidden size + esm_dim = self.config.hidden_size + + # Output channels for CNN layers + output_channels_per_kernel = 64 + + # CNN layers for handling variable length sequences + self.protein_conv_layers = nn.ModuleList([ + nn.Conv1d( + in_channels=esm_dim, + out_channels=output_channels_per_kernel, + kernel_size=k, + padding='same' + ) for k in kernel_sizes + ]) + + self.binder_conv_layers = nn.ModuleList([ + nn.Conv1d( + in_channels=esm_dim, + out_channels=output_channels_per_kernel, + kernel_size=k, + padding='same' + ) for k in kernel_sizes + ]) + + # Calculate total features after convolution and pooling + total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2 + + # Project to same dimension after CNN processing + self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim) + self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim) + + self.protein_norm = nn.LayerNorm(hidden_dim) + self.binder_norm = nn.LayerNorm(hidden_dim) + + # Cross attention blocks with layer norm + self.cross_attention_layers = nn.ModuleList([ + nn.ModuleDict({ + 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout), + 'norm1': nn.LayerNorm(hidden_dim), + 'ffn': nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim) + ), + 'norm2': nn.LayerNorm(hidden_dim) + }) for _ in range(n_layers) + ]) + + # Prediction heads + self.shared_head = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + ) + + # Regression head + self.regression_head = nn.Linear(hidden_dim, 1) + + # Classification head (3 classes: tight, medium, loose binding) + self.classification_head = nn.Linear(hidden_dim, 3) + + def get_binding_class(self, affinity): + """Convert affinity values to class indices + 0: tight 
binding (>= 7.5) + 1: medium binding (6.0-7.5) + 2: weak binding (< 6.0) + """ + if isinstance(affinity, torch.Tensor): + tight_mask = affinity >= self.tight_threshold + weak_mask = affinity < self.weak_threshold + medium_mask = ~(tight_mask | weak_mask) + + classes = torch.zeros_like(affinity, dtype=torch.long) + classes[medium_mask] = 1 + classes[weak_mask] = 2 + return classes + else: + if affinity >= self.tight_threshold: + return 0 # tight binding + elif affinity < self.weak_threshold: + return 2 # weak binding + else: + return 1 # medium binding + + def compute_embeddings(self, input_ids, attention_mask=None): + """Compute ESM embeddings on the fly""" + esm_outputs = self.esm_model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True + ) + + # Get the unpooled last hidden states (batch_size x seq_length x hidden_size) + return esm_outputs.last_hidden_state + + def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None): + """Process a sequence through CNN layers and pooling""" + # Transpose for CNN: [batch_size, hidden_size, seq_length] + x = unpooled_emb.transpose(1, 2) + + # Apply CNN layers and collect outputs + conv_outputs = [] + for conv in conv_layers: + conv_out = F.relu(conv(x)) + conv_outputs.append(conv_out) + + # Concatenate along channel dimension + conv_output = torch.cat(conv_outputs, dim=1) + + # Global pooling (both max and average) + # If attention mask is provided, use it to create a proper mask for pooling + if attention_mask is not None: + # Create a mask for pooling (1 for valid positions, 0 for padding) + # Expand mask to match conv_output channels + expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1) + + # Apply mask (set padding to large negative value for max pooling) + masked_output = conv_output.clone() + masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf')) + + # Max pooling along sequence dimension + max_pooled = torch.max(masked_output, 
dim=2)[0] + + # Average pooling (sum divided by number of valid positions) + sum_pooled = torch.sum(conv_output * expanded_mask, dim=2) + valid_positions = torch.sum(expanded_mask, dim=2) + valid_positions = torch.clamp(valid_positions, min=1.0) # Avoid division by zero + avg_pooled = sum_pooled / valid_positions + else: + # If no mask, use standard pooling + max_pooled = torch.max(conv_output, dim=2)[0] + avg_pooled = torch.mean(conv_output, dim=2) + + # Concatenate the pooled features + pooled = torch.cat([max_pooled, avg_pooled], dim=1) + + return pooled + + def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None): + # Compute embeddings on the fly using the ESM model + protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask) + binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask) + + # Process protein and binder sequences through CNN layers + protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask) + binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask) + + # Project to same dimension + protein = self.protein_norm(self.protein_projection(protein_features)) + binder = self.binder_norm(self.binder_projection(binder_features)) + + # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim] + protein = protein.unsqueeze(0) + binder = binder.unsqueeze(0) + + # Cross attention layers + for layer in self.cross_attention_layers: + # Protein attending to binder + attended_protein = layer['attention']( + protein, binder, binder + )[0] + protein = layer['norm1'](protein + attended_protein) + protein = layer['norm2'](protein + layer['ffn'](protein)) + + # Binder attending to protein + attended_binder = layer['attention']( + binder, protein, protein + )[0] + binder = layer['norm1'](binder + attended_binder) + binder = layer['norm2'](binder + layer['ffn'](binder)) + + # Remove sequence dimension 
+ protein_pool = protein.squeeze(0) + binder_pool = binder.squeeze(0) + + # Concatenate both representations + combined = torch.cat([protein_pool, binder_pool], dim=-1) + + # Shared features + shared_features = self.shared_head(combined) + + regression_output = self.regression_head(shared_features) + # classification_logits = self.classification_head(shared_features) + + # return regression_output, classification_logits + return regression_output + +class ImprovedBindingPredictor(nn.Module): + def __init__(self, + esm_dim=1280, + smiles_dim=1280, + hidden_dim=512, + n_heads=8, + n_layers=5, + dropout=0.1): + super().__init__() + + # Define binding thresholds + self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM + self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM + + # Project to same dimension + self.smiles_projection = nn.Linear(smiles_dim, hidden_dim) + self.protein_projection = nn.Linear(esm_dim, hidden_dim) + self.protein_norm = nn.LayerNorm(hidden_dim) + self.smiles_norm = nn.LayerNorm(hidden_dim) + + # Cross attention blocks with layer norm + self.cross_attention_layers = nn.ModuleList([ + nn.ModuleDict({ + 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout), + 'norm1': nn.LayerNorm(hidden_dim), + 'ffn': nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim) + ), + 'norm2': nn.LayerNorm(hidden_dim) + }) for _ in range(n_layers) + ]) + + # Prediction heads + self.shared_head = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + ) + + # Regression head + self.regression_head = nn.Linear(hidden_dim, 1) + + # Classification head (3 classes: tight, medium, loose binding) + self.classification_head = nn.Linear(hidden_dim, 3) + + def get_binding_class(self, affinity): + """Convert affinity values to class indices + 0: tight binding (>= 7.5) + 1: medium binding (6.0-7.5) + 2: weak binding (< 6.0) + """ + if isinstance(affinity, 
torch.Tensor): + tight_mask = affinity >= self.tight_threshold + weak_mask = affinity < self.weak_threshold + medium_mask = ~(tight_mask | weak_mask) + + classes = torch.zeros_like(affinity, dtype=torch.long) + classes[medium_mask] = 1 + classes[weak_mask] = 2 + return classes + else: + if affinity >= self.tight_threshold: + return 0 # tight binding + elif affinity < self.weak_threshold: + return 2 # weak binding + else: + return 1 # medium binding + + def forward(self, protein_emb, binder_emb): + + protein = self.protein_norm(self.protein_projection(protein_emb)) + smiles = self.smiles_norm(self.smiles_projection(binder_emb)) + + protein = protein.transpose(0, 1) + smiles = smiles.transpose(0, 1) + + # Cross attention layers + for layer in self.cross_attention_layers: + # Protein attending to SMILES + attended_protein = layer['attention']( + protein, smiles, smiles + )[0] + protein = layer['norm1'](protein + attended_protein) + protein = layer['norm2'](protein + layer['ffn'](protein)) + + # SMILES attending to protein + attended_smiles = layer['attention']( + smiles, protein, protein + )[0] + smiles = layer['norm1'](smiles + attended_smiles) + smiles = layer['norm2'](smiles + layer['ffn'](smiles)) + + # Get sequence-level representations + protein_pool = torch.mean(protein, dim=0) + smiles_pool = torch.mean(smiles, dim=0) + + # Concatenate both representations + combined = torch.cat([protein_pool, smiles_pool], dim=-1) + + # Shared features + shared_features = self.shared_head(combined) + + regression_output = self.regression_head(shared_features) + + return regression_output + +class PooledAffinityModel(nn.Module): + def __init__(self, affinity_predictor, target_sequence): + super(PooledAffinityModel, self).__init__() + self.affinity_predictor = affinity_predictor + self.target_sequence = target_sequence + self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.target_sequence.device) + for param in self.esm_model.parameters(): + 
param.requires_grad = False + + def compute_embeddings(self, input_ids, attention_mask=None): + """Compute ESM embeddings on the fly""" + esm_outputs = self.esm_model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True + ) + + # Get the unpooled last hidden states (batch_size x seq_length x hidden_size) + return esm_outputs.last_hidden_state + + def forward(self, x): + target_sequence = self.target_sequence.repeat(x.shape[0], 1) + + protein_emb = self.compute_embeddings(input_ids=target_sequence) + binder_emb = self.compute_embeddings(input_ids=x) + return self.affinity_predictor(protein_emb=protein_emb, binder_emb=binder_emb).squeeze(-1) + +class AffinityModel(nn.Module): + def __init__(self, affinity_predictor, target_sequence): + super(AffinityModel, self).__init__() + self.affinity_predictor = affinity_predictor + self.target_sequence = target_sequence + + def forward(self, x): + target_sequence = self.target_sequence.repeat(x.shape[0], 1) + affinity = self.affinity_predictor(protein_input_ids=target_sequence, binder_input_ids=x).squeeze(-1) + return affinity / 10 + +class HemolysisModel: + def __init__(self, device): + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_hemolysis.json') + + self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device) + self.model.eval() + + self.device = device + + def generate_embeddings(self, sequences): + """Generate ESM embeddings for protein sequences""" + with torch.no_grad(): + embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1) + embeddings = embeddings.cpu().numpy() + + return embeddings + + def get_scores(self, input_seqs): + scores = np.ones(len(input_seqs)) + features = self.generate_embeddings(input_seqs) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) 
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + probs = self.predictor.predict(features) + # return the probability of it being not hemolytic + return torch.from_numpy(scores - probs).to(self.device) + + def __call__(self, input_seqs: list): + scores = self.get_scores(input_seqs) + return scores + +class NonfoulingModel: + def __init__(self, device): + # change model path + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_nonfouling.json') + + self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device) + self.model.eval() + + self.device = device + + def generate_embeddings(self, sequences): + """Generate ESM embeddings for protein sequences""" + with torch.no_grad(): + embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1) + embeddings = embeddings.cpu().numpy() + + return embeddings + + def get_scores(self, input_seqs): + scores = np.zeros(len(input_seqs)) + features = self.generate_embeddings(input_seqs) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) 
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + scores = self.predictor.predict(features) + return torch.from_numpy(scores).to(self.device) + + def __call__(self, input_seqs: list): + scores = self.get_scores(input_seqs) + return scores + +class SolubilityModel: + def __init__(self, device): + # change model path + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json') + + self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device) + self.model.eval() + + self.device = device + + def generate_embeddings(self, sequences): + """Generate ESM embeddings for protein sequences""" + with torch.no_grad(): + embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1) + embeddings = embeddings.cpu().numpy() + + return embeddings + + def get_scores(self, input_seqs: list): + scores = np.zeros(len(input_seqs)) + features = self.generate_embeddings(input_seqs) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) 
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + scores = self.predictor.predict(features) + return torch.from_numpy(scores).to(self.device) + + def __call__(self, input_seqs: list): + scores = self.get_scores(input_seqs) + return scores + +class SolubilityModelNew: + def __init__(self, device): + self.hydro_ids = torch.tensor([5, 7, 4, 12, 20, 18, 22, 14], device=device) + self.device = device + + def get_scores(self, x): + mask = (x.unsqueeze(-1) == self.hydro_ids).any(dim=-1) + ratios = mask.float().mean(dim=1) + return 1 - ratios + + def __call__(self, input_seqs: list): + scores = self.get_scores(input_seqs) + return scores + +class PeptideCNN(nn.Module): + def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate): + super().__init__() + self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1) + self.fc = nn.Linear(hidden_dims[1], output_dim) + self.dropout = nn.Dropout(dropout_rate) + self.predictor = nn.Linear(output_dim, 1) # For regression/classification + + self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D") + self.esm_model.eval() + + def forward(self, input_ids, attention_mask=None, return_features=False): + with torch.no_grad(): + x = self.esm_model(input_ids, attention_mask).last_hidden_state + # x shape: (B, L, input_dim) + x = x.permute(0, 2, 1) # Reshape to (B, input_dim, L) for Conv1d + x = nn.functional.relu(self.conv1(x)) + x = self.dropout(x) + x = nn.functional.relu(self.conv2(x)) + x = self.dropout(x) + x = x.permute(0, 2, 1) # Reshape back to (B, L, hidden_dims[1]) + + # Global average pooling over the sequence dimension (L) + x = x.mean(dim=1) # Shape: (B, hidden_dims[1]) + + features = self.fc(x) # features shape: (B, output_dim) + if return_features: + return features + return self.predictor(features) # Output shape: (B, 1) + +class 
HalfLifeModel: + def __init__(self, device): + input_dim = 1280 + hidden_dims = [input_dim // 2, input_dim // 4] + output_dim = input_dim // 8 + dropout_rate = 0.3 + self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device) + self.model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False)) + self.model.eval() + + def __call__(self, x): + prediction = self.model(x, return_features=False) + halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0) + return halflife / 2 + + +def load_bindevaluator(checkpoint_path, device): + bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device) + bindevaluator.eval() + for param in bindevaluator.parameters(): + param.requires_grad = False + + return bindevaluator + + +def load_solver(checkpoint_path, vocab_size, device): + lr = 1e-4 + epochs = 200 + embed_dim = 512 + hidden_dim = 256 + epsilon = 1e-3 + batch_size = 256 + warmup_epochs = epochs // 10 + device = 'cuda:0' + + + probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device) + probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False)) + probability_denoiser.eval() + for param in probability_denoiser.parameters(): + param.requires_grad = False + + # instantiate a convex path object + scheduler = PolynomialConvexScheduler(n=2.0) + path = MixtureDiscreteProbPath(scheduler=scheduler) + + class WrappedModel(ModelWrapper): + def forward(self, x: torch.Tensor, t: torch.Tensor, **extras): + return torch.softmax(self.model(x, t), dim=-1) + + wrapped_probability_denoiser = WrappedModel(probability_denoiser) + solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size) + + return solver + + +def 
load_pooled_affinity_predictor(checkpoint_path, device): + """Load trained model from checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + + model = ImprovedBindingPredictor().to(device) + + # Load the trained weights + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() # Set to evaluation mode + + return model + +def load_affinity_predictor(checkpoint_path, device): + """Load trained model from checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + + model = UnpooledBindingPredictor( + esm_model_name="facebook/esm2_t33_650M_UR50D", + hidden_dim=384, + kernel_sizes=[3, 5, 7], + n_heads=8, + n_layers=4, + dropout=0.14561457009902096, + freeze_esm=True + ).to(device) + + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() + + return model diff --git a/models/peptide_models.py b/models/peptide_models.py new file mode 100644 index 0000000000000000000000000000000000000000..8fb81e23b1bef87f52fb8aff7d2440aabd9ac21e --- /dev/null +++ b/models/peptide_models.py @@ -0,0 +1,359 @@ +from torch import nn +import torch +import numpy as np +from transformers import AutoModel +import torch.nn.functional as F +import esm +import copy +import pdb + +class GaussianFourierProjection(nn.Module): + """ + Gaussian random features for encoding time steps. + """ + + def __init__(self, embed_dim, scale=30.): + super().__init__() + # Randomly sample weights during initialization. These weights are fixed + # during optimization and are not trainable. + self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + +class Dense(nn.Module): + """ + A fully connected layer that reshapes outputs to feature maps. 
+ """ + + def __init__(self, input_dim, output_dim): + super().__init__() + self.dense = nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.dense(x)[...] + +class Swish(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sigmoid(x) * x + +class CNNESMModel(nn.Module): + """A time-dependent score-based model built upon U-Net architecture.""" + + def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256): + """ + Args: + embed_dim (int): Dimensionality of the token and time embeddings. + """ + super().__init__() + self.alphabet_size = alphabet_size + + # self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim) + self.esm = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D") + self.esm.eval() + for param in self.esm.parameters(): + param.requires_grad = False + + self.time_embed = nn.Sequential( + GaussianFourierProjection(embed_dim=embed_dim), + nn.Linear(embed_dim, embed_dim) + ) + + self.swish = Swish() + + n = hidden_dim + + self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4) + + self.blocks = nn.ModuleList([ + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # 
nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256) + ]) + + self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)]) + self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)]) + + self.final = nn.Sequential( + nn.Conv1d(n, n, kernel_size=1), + nn.GELU(), + nn.Conv1d(n, self.alphabet_size, kernel_size=1) + ) + + + def forward(self, x, t): + """ + Args: + x: Tensor of shape (B, L) containing DNA token indices. + t: Tensor of shape (B,) containing the time steps. + Returns: + out: Tensor of shape (B, L, 4) with output logits for each DNA base. + """ + # x = self.token_embedding(x) # (B, L) -> (B, L, embed_dim) + with torch.no_grad(): + x = self.esm(input_ids=x).last_hidden_state + time_embed = self.swish(self.time_embed(t)) # (B, embed_dim) + + out = x.permute(0, 2, 1) # (B, L, embed_dim) -> (B, embed_dim, L) + out = self.swish(self.linear(out)) # (B, n, L) + + # Process through convolutional blocks, adding time conditioning via dense layers. + for block, dense, norm in zip(self.blocks, self.denses, self.norms): + # dense(embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting. + h = self.swish(block(norm(out + dense(time_embed)[:, :, None]))) + # Residual connection if shapes match. 
+ if h.shape == out.shape: + out = h + out + else: + out = h + + out = self.final(out) # (B, 4, L) + out = out.permute(0, 2, 1) # (B, L, 4) + + # Normalization + out = out - out.mean(dim=-1, keepdim=True) + return out + + +class MLPModel(nn.Module): + def __init__( + self, input_dim: int = 128, time_dim: int = 1, hidden_dim=128, length=500): + super().__init__() + self.input_dim = input_dim + self.time_dim = time_dim + self.hidden_dim = hidden_dim + + self.time_embedding = nn.Linear(1, time_dim) + self.token_embedding = torch.nn.Embedding(self.input_dim, hidden_dim) + + self.swish = Swish() + + self.main = nn.Sequential( + self.swish, + nn.Linear(hidden_dim * length + time_dim, hidden_dim), + self.swish, + nn.Linear(hidden_dim, hidden_dim), + self.swish, + nn.Linear(hidden_dim, hidden_dim), + self.swish, + nn.Linear(hidden_dim, self.input_dim * length), + ) + + def forward(self, x, t): + ''' + x shape (B,L) + t shape (B,) + ''' + t = self.time_embedding(t.unsqueeze(-1)) + x = self.token_embedding(x) + + B, N, d = x.shape + x = x.reshape(B, N * d) + + h = torch.cat([x, t], dim=1) + h = self.main(h) + + h = h.reshape(B, N, self.input_dim) + + return h + +class CNNModel(nn.Module): + """A time-dependent score-based model built upon U-Net architecture.""" + + def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256): + """ + Args: + embed_dim (int): Dimensionality of the token and time embeddings. 
+ """ + super().__init__() + self.alphabet_size = alphabet_size + + self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim) + # self.esm = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D") + # self.esm.eval() + # for param in self.esm.parameters(): + # param.requires_grad = False + + self.time_embed = nn.Sequential( + GaussianFourierProjection(embed_dim=embed_dim), + nn.Linear(embed_dim, embed_dim) + ) + + self.swish = Swish() + + n = hidden_dim + + self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4) + + self.blocks = nn.ModuleList([ + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, padding=4), + # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256) + ]) + + self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)]) + self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)]) + + self.final = nn.Sequential( + nn.Conv1d(n, n, kernel_size=1), + nn.GELU(), + nn.Conv1d(n, self.alphabet_size, kernel_size=1) + ) + + def forward(self, x, t): + """ + Args: + x: 
Tensor of shape (B, L) containing DNA token indices. + t: Tensor of shape (B,) containing the time steps. + Returns: + out: Tensor of shape (B, L, 4) with output logits for each DNA base. + """ + x = self.token_embedding(x) # (B, L) -> (B, L, embed_dim) + # with torch.no_grad(): + # x = self.esm(input_ids=x).last_hidden_state + time_embed = self.swish(self.time_embed(t)) # (B, embed_dim) + + out = x.permute(0, 2, 1) # (B, L, embed_dim) -> (B, embed_dim, L) + out = self.swish(self.linear(out)) # (B, n, L) + + # Process through convolutional blocks, adding time conditioning via dense layers. + for block, dense, norm in zip(self.blocks, self.denses, self.norms): + # dense(embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting. + h = self.swish(block(norm(out + dense(time_embed)[:, :, None]))) + # Residual connection if shapes match. + if h.shape == out.shape: + out = h + out + else: + out = h + + out = self.final(out) # (B, 4, L) + out = out.permute(0, 2, 1) # (B, L, 4) + + # Normalization + out = out - out.mean(dim=-1, keepdim=True) + return out + +class CNNModel_Large(nn.Module): + """A time-dependent score-based model built upon U-Net architecture.""" + + def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256): + """ + Args: + embed_dim (int): Dimensionality of the token and time embeddings. 
+ """ + super().__init__() + self.alphabet_size = alphabet_size + + self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim) + + self.time_embed = nn.Sequential( + GaussianFourierProjection(embed_dim=embed_dim), + nn.Linear(embed_dim, embed_dim) + ) + + self.swish = Swish() + + n = hidden_dim + + self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4) + + self.blocks = nn.ModuleList([ + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, padding=4), + nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), + nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), + nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256) + ]) + + self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(20)]) + self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(20)]) + + self.final = nn.Sequential( + nn.Conv1d(n, n, kernel_size=1), + nn.GELU(), + nn.Conv1d(n, self.alphabet_size, kernel_size=1) + ) + + def forward(self, x, t): + """ + Args: + x: Tensor of shape (B, L) containing DNA token indices. + t: Tensor of shape (B,) containing the time steps. + Returns: + out: Tensor of shape (B, L, 4) with output logits for each DNA base. 
+ """ + x = self.token_embedding(x) # (B, L) -> (B, L, embed_dim) + time_embed = self.swish(self.time_embed(t)) # (B, embed_dim) + + out = x.permute(0, 2, 1) # (B, L, embed_dim) -> (B, embed_dim, L) + out = self.swish(self.linear(out)) # (B, n, L) + + # Process through convolutional blocks, adding time conditioning via dense layers. + for block, dense, norm in zip(self.blocks, self.denses, self.norms): + # dense(embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting. + h = self.swish(block(norm(out + dense(time_embed)[:, :, None]))) + # Residual connection if shapes match. + if h.shape == out.shape: + out = h + out + else: + out = h + + out = self.final(out) # (B, 4, L) + out = out.permute(0, 2, 1) # (B, L, 4) + + # Normalization + out = out - out.mean(dim=-1, keepdim=True) + return out \ No newline at end of file diff --git a/modules/bindevaluator_modules/__init__.py b/modules/bindevaluator_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0175e72e9d2cbada6c5e3d8f10c57a34192a5bb0 --- /dev/null +++ b/modules/bindevaluator_modules/__init__.py @@ -0,0 +1,3 @@ +from .models import * +from .score_domain import * +from .dataloaders import * diff --git a/modules/bindevaluator_modules/__pycache__/__init__.cpython-310.pyc b/modules/bindevaluator_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..375d081c8d0e3b2c6c45afb9362ef5f65cc289e0 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/__init__.cpython-38.pyc b/modules/bindevaluator_modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b4f7356c3757dc6cd62fa2536570b307264750 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/__init__.cpython-39.pyc 
b/modules/bindevaluator_modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03496b75ca230a5fe3fbdaf1fea1f8a9a91f6e1a Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-310.pyc b/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec6c5dd6abafb90dd6cd8b03d6aca48d2ec0bcf5 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-310.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-38.pyc b/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e827c277a2d52ad0166efb98478ffd9eecd909 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-38.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-39.pyc b/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29040c38fe478a3d2c5e50ef00f03a1bed73cf45 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/dataloaders.cpython-39.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/layers.cpython-310.pyc b/modules/bindevaluator_modules/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daca436dc18593bb778268eae02cbc267ea91077 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/layers.cpython-310.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/layers.cpython-38.pyc b/modules/bindevaluator_modules/__pycache__/layers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98676571e287931972131ddbc57e114106581ed1 
Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/layers.cpython-38.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/layers.cpython-39.pyc b/modules/bindevaluator_modules/__pycache__/layers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86b674cd51757e10650eb17e1481c1ec40022b64 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/layers.cpython-39.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/models.cpython-310.pyc b/modules/bindevaluator_modules/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b1248c69391216aca77e784f942111c423c76df Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/models.cpython-310.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/models.cpython-38.pyc b/modules/bindevaluator_modules/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d597292a702c705120ef9eb07fb4a32414a29b27 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/models.cpython-38.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/models.cpython-39.pyc b/modules/bindevaluator_modules/__pycache__/models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1aea3fea4d8103ace699d1c3eba7d09633ab5180 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/models.cpython-39.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/modules.cpython-310.pyc b/modules/bindevaluator_modules/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43e42a397f0926356f21a35a7710be104215cf2b Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/modules.cpython-310.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/modules.cpython-38.pyc 
b/modules/bindevaluator_modules/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8969d15ca3545aae9e8de7c5b0aab8de918a264c Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/modules.cpython-38.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/modules.cpython-39.pyc b/modules/bindevaluator_modules/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f751a53e8117c35ea2cf1c1939b2595c562afdc9 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/modules.cpython-39.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/score_domain.cpython-310.pyc b/modules/bindevaluator_modules/__pycache__/score_domain.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7d933e2dd20727acdde19ea4801a1893732cd47 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/score_domain.cpython-310.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/score_domain.cpython-38.pyc b/modules/bindevaluator_modules/__pycache__/score_domain.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd1d372c40837120f7c4e1777a6fedae192c8c39 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/score_domain.cpython-38.pyc differ diff --git a/modules/bindevaluator_modules/__pycache__/score_domain.cpython-39.pyc b/modules/bindevaluator_modules/__pycache__/score_domain.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b68026f9e83b7e678e7b3cae1bf44f38bdd120a4 Binary files /dev/null and b/modules/bindevaluator_modules/__pycache__/score_domain.cpython-39.pyc differ diff --git a/modules/bindevaluator_modules/dataloaders.py b/modules/bindevaluator_modules/dataloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d9deaf0ab33682d781b94af70f54d8d0883c8e --- /dev/null +++ 
b/modules/bindevaluator_modules/dataloaders.py @@ -0,0 +1,426 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Jul 31 21:54:08 2021 + +@author: Osama +""" + +from torch.utils.data import Dataset +from Bio.PDB import Polypeptide +import numpy as np +import torch +import pandas as pd +import os +# import esm +import ast +import pdb + + +class InterpepComplexes(Dataset): + + def __init__(self, mode, + encoded_data_directory = "../../datasets/interpep_data/"): + + self.mode = mode + + self.encoded_data_directory = encoded_data_directory + + self.train_dir = "../../datasets/interpep_data/train_examples.npy" + + self.test_dir = "../../datasets/interpep_data/test_examples.npy" + + self.val_dir = "../../datasets/interpep_data/val_examples.npy" + + + self.test_list = np.load(self.test_dir) + + self.train_list = np.load(self.train_dir) + + self.val_list = np.load(self.val_dir) + + + + if mode == "train": + self.num_data = len(self.train_list) + elif mode == "val": + self.num_data = len(self.val_list) + elif mode == "test": + self.num_data = len(self.test_list) + + + + def __getitem__(self, index): + + if self.mode == "train": + item = self.train_list[index] + elif self.mode == "val": + item = self.val_list[index] + elif self.mode == "test": + item = self.test_list[index] + + file_dir = self.encoded_data_directory + + with np.load(file_dir + "fragment_data/" + item + ".npz") as data: + temp_pep_sequence = data["target_sequence"] + temp_binding_sites = data["binding_sites"] + + + with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\ + item.split("_")[1] + ".npz") as data: + temp_nodes = data["nodes"] + + + binding = np.zeros(len(temp_nodes)) + if len(temp_binding_sites) != 0: + binding[temp_binding_sites] = 1 + target = torch.LongTensor(binding) + + + + + + + + nodes = temp_nodes[:, 0:20] + + prot_sequence = np.argmax(nodes, axis=-1) + + + + prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence]) + + + + pep_sequence = temp_pep_sequence 
+ + pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1) + + + + + + return pep_sequence, prot_sequence, target + + def __len__(self): + return self.num_data + +class PPI(Dataset): + + def __init__(self, mode, csv_dir_path = "/home/u21307130002/PepNN/pepnn/datasets/ppi/"): + + self.mode = mode + self.train_data = pd.read_csv(os.path.join(csv_dir_path, 'train.csv')) + self.val_data = pd.read_csv(os.path.join(csv_dir_path, 'val.csv')) + # self.test_data = pd.read_csv(os.path.join(csv_dir_path, 'test.csv')) + + if self.mode == 'train': + self.num_data = len(self.train_data) + + def __len__(self): + return self.num_data + + def __getitem__(self, index): + # pdb.set_trace() + if torch.is_tensor(index): + index = index.tolist() + + if self.mode == "train": + item = self.train_data.iloc[index] + elif self.mode == "val": + item = self.val_data.iloc[index] + elif self.mode == "test": + item = self.test_data.iloc[index] + else: + item = None + + # print(item) + + motif1 = ast.literal_eval(item['Chain_1_motifs']) + motif2 = ast.literal_eval(item['Chain_2_motifs']) + + if len(motif1[0]) > len(motif2[0]): + target = motif1 + prot_sequence = item['Sequence1'] + pep_sequence = item['Sequence2'] + else: + target = motif2 + pep_sequence = item['Sequence1'] + prot_sequence = item['Sequence2'] + + target = [int(motif.split('_')[1]) for motif in target] + + if target[-1] >= len(prot_sequence): + pdb.set_trace() + + binding = np.zeros(len(prot_sequence)) + if len(target) != 0: + binding[target] = 1 + target = torch.LongTensor(binding).float() + + # print(f"peptide length: {len(pep_sequence)}") + # print(f"protein length: {len(prot_sequence)}") + # print(f"target length: {len(target)}") + # pdb.set_trace() + + return pep_sequence, prot_sequence, target + + + + +class PepBindComplexes(Dataset): + + def __init__(self, mode, + encoded_data_directory = "../../datasets/pepbind_data/"): + + self.mode = mode + + self.encoded_data_directory = encoded_data_directory + + 
self.train_dir = "../../datasets/pepbind_data/train_examples.npy" + + self.test_dir = "../../datasets/pepbind_data/test_examples.npy" + + self.val_dir = "../../datasets/pepbind_data/val_examples.npy" + + + self.test_list = np.load(self.test_dir) + + self.train_list = np.load(self.train_dir) + + self.val_list = np.load(self.val_dir) + + + if mode == "train": + self.num_data = len(self.train_list) + elif mode == "val": + self.num_data = len(self.val_list) + elif mode == "test": + self.num_data = len(self.test_list) + + + + def __getitem__(self, index): + + if self.mode == "train": + item = self.train_list[index] + + + elif self.mode == "val": + item = self.val_list[index] + + + elif self.mode == "test": + item = self.test_list[index] + + + + file_dir = self.encoded_data_directory + + + with np.load(file_dir + "fragment_data/" + item + ".npz") as data: + temp_pep_sequence = data["target_sequence"] + temp_binding_sites = data["binding_sites"] + + + with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\ + item.split("_")[1] + ".npz") as data: + temp_nodes = data["nodes"] + + + binding = np.zeros(len(temp_nodes)) + if len(temp_binding_sites) != 0: + binding[temp_binding_sites] = 1 + target = torch.LongTensor(binding) + + nodes = temp_nodes[:, 0:20] + + prot_sequence = np.argmax(nodes, axis=-1) + + + prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence]) + + + pep_sequence = temp_pep_sequence + + pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1) + + + return pep_sequence, prot_sequence, target + + + def __len__(self): + return self.num_data + +class PeptideComplexes(Dataset): + + def __init__(self, mode, + encoded_data_directory = "../../datasets/pepnn_data/all_data/"): + + self.mode = mode + + self.encoded_data_directory = encoded_data_directory + + self.train_dir = "../../datasets/pepnn_data/train_examples.npy" + + self.test_dir = "../../datasets/pepnn_test_data/test_examples.npy" + + self.val_dir = 
"../../datasets/pepnn_data/val_examples.npy" + + + self.example_weights = np.load("../../datasets/pepnn_data/example_weights.npy") + + self.test_list = np.load(self.test_dir) + + self.train_list = np.load(self.train_dir) + + self.val_list = np.load(self.val_dir) + + + + if mode == "train": + self.num_data = len(self.train_list) + elif mode == "val": + self.num_data = len(self.val_list) + elif mode == "test": + self.num_data = len(self.test_list) + + + + def __getitem__(self, index): + + + if self.mode == "train": + item = self.train_list[index] + + weight = self.example_weights[item] + + elif self.mode == "val": + item = self.val_list[index] + + weight = self.example_weights[item] + + elif self.mode == "test": + item = self.test_list[index] + + weight = 1 + + if self.mode != "test": + file_dir = self.encoded_data_directory + else: + file_dir = "../../datasets/pepnn_test_data/all_data/" + + + with np.load(file_dir + "fragment_data/" + item + ".npz") as data: + temp_pep_sequence = data["target_sequence"] + temp_binding_sites = data["binding_sites"] + + + with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\ + item.split("_")[1] + ".npz") as data: + temp_nodes = data["nodes"] + + + binding = np.zeros(len(temp_nodes)) + if len(temp_binding_sites) != 0: + binding[temp_binding_sites] = 1 + target = torch.LongTensor(binding) + + + + + + + + nodes = temp_nodes[:, 0:20] + + prot_sequence = np.argmax(nodes, axis=-1) + + + + prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence]) + + + + pep_sequence = temp_pep_sequence + + pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1) + + + + + + return pep_sequence, prot_sequence, target, weight + + + def __len__(self): + return self.num_data + + +class BitenetComplexes(Dataset): + + def __init__(self, encoded_data_directory = "../bitenet_data/all_data/"): + + + self.encoded_data_directory = encoded_data_directory + + + + + self.train_dir = 
"../../datasets/bitenet_data/examples.npy" + + + + + self.full_list = np.load(self.train_dir) + + + + + self.num_data = len(self.full_list) + + + + + def __getitem__(self, index): + + item = self.full_list[index] + + file_dir = self.encoded_data_directory + + with np.load(file_dir + "fragment_data/" + item[:-1] + "_" + item[-1] + ".npz") as data: + temp_pep_sequence = data["target_sequence"] + temp_binding_matrix = data["binding_matrix"] + + + with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\ + item.split("_")[1][0] + ".npz") as data: + temp_nodes = data["nodes"] + + + binding_sum = np.sum(temp_binding_matrix, axis=0).T + + target = torch.LongTensor(binding_sum >= 1) + + + + nodes = temp_nodes[:, 0:20] + + prot_sequence = np.argmax(nodes, axis=-1) + + + + prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence]) + + + + pep_sequence = temp_pep_sequence + + pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1) + + + + + return pep_sequence, prot_sequence, target + + def __len__(self): + return self.num_data \ No newline at end of file diff --git a/modules/bindevaluator_modules/layers.py b/modules/bindevaluator_modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..179fa1ce31ecb11506e06430a7a70adc159e9565 --- /dev/null +++ b/modules/bindevaluator_modules/layers.py @@ -0,0 +1,142 @@ +from torch import nn +from .modules import * +import pdb + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding, dilation): + super(ConvLayer, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) + self.relu = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + +class DilatedCNN(nn.Module): + def __init__(self, d_model, d_hidden): + super(DilatedCNN, self).__init__() + self.first_ = nn.ModuleList() + self.second_ = nn.ModuleList() + self.third_ = 
nn.ModuleList() + + dilation_tuple = (1, 2, 3) + dim_in_tuple = (d_model, d_hidden, d_hidden) + dim_out_tuple = (d_hidden, d_hidden, d_hidden) + + for i, dilation_rate in enumerate(dilation_tuple): + self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate, + dilation=dilation_rate)) + + for i, dilation_rate in enumerate(dilation_tuple): + self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate, + dilation=dilation_rate)) + + for i, dilation_rate in enumerate(dilation_tuple): + self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate, + dilation=dilation_rate)) + + def forward(self, protein_seq_enc): + # pdb.set_trace() + protein_seq_enc = protein_seq_enc.transpose(1, 2) # protein_seq_enc's shape: B*L*d_model -> B*d_model*L + + first_embedding = protein_seq_enc + second_embedding = protein_seq_enc + third_embedding = protein_seq_enc + + for i in range(len(self.first_)): + first_embedding = self.first_[i](first_embedding) + + for i in range(len(self.second_)): + second_embedding = self.second_[i](second_embedding) + + for i in range(len(self.third_)): + third_embedding = self.third_[i](third_embedding) + + # pdb.set_trace() + + protein_seq_enc = first_embedding + second_embedding + third_embedding + + return protein_seq_enc.transpose(1, 2) + + +class ReciprocalLayerwithCNN(nn.Module): + + def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v): + super().__init__() + + self.cnn = DilatedCNN(d_model, d_hidden) + + self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, + d_k, d_v) + + self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, + d_k, d_v) + + self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, + d_k, d_v) + + self.ffn_seq = FFN(d_hidden, d_inner) + + self.ffn_protein = FFN(d_hidden, d_inner) + + def forward(self, sequence_enc, 
protein_seq_enc): + # pdb.set_trace() # protein_seq_enc.shape = B * L * d_model + protein_seq_enc = self.cnn(protein_seq_enc) + prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc) + + seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc) + + prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, + seq_enc, + seq_enc, + prot_enc) + prot_enc = self.ffn_protein(prot_enc) + + seq_enc = self.ffn_seq(seq_enc) + + return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention + + +class ReciprocalLayer(nn.Module): + + def __init__(self, d_model, d_inner, n_head, d_k, d_v): + + super().__init__() + + self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v) + + self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v) + + self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, + d_k, d_v) + + + + self.ffn_seq = FFN(d_model, d_inner) + + self.ffn_protein = FFN(d_model, d_inner) + + def forward(self, sequence_enc, protein_seq_enc): + prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc) + + seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc) + + + prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, + seq_enc, + seq_enc, + prot_enc) + prot_enc = self.ffn_protein(prot_enc) + + seq_enc = self.ffn_seq(seq_enc) + + + + return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention + + + diff --git a/modules/bindevaluator_modules/models.py b/modules/bindevaluator_modules/models.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0a5f672a6767de6ab6ba0b6f759a56167a9a78 --- /dev/null +++ 
b/modules/bindevaluator_modules/models.py @@ -0,0 +1,284 @@ +import pdb + +import numpy as np +import torch +import torch.nn as nn +from .layers import * +from .modules import * +import pdb +from transformers import EsmModel, EsmTokenizer + +def to_var(x): + if torch.cuda.is_available(): + x = x.cuda() + return x + + +class RepeatedModule3(nn.Module): + def __init__(self, n_layers, d_model, d_hidden, + n_head, d_k, d_v, d_inner, dropout=0.1): + super().__init__() + + self.linear1 = nn.Linear(1280, d_model) + self.linear2 = nn.Linear(1280, d_model) + self.sequence_embedding = nn.Embedding(20, d_model) + self.d_model = d_model + + self.reciprocal_layer_stack = nn.ModuleList([ + ReciprocalLayerwithCNN(d_model, d_inner, d_hidden, n_head, d_k, d_v) + for _ in range(n_layers)]) + + self.dropout = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, peptide_sequence, protein_sequence): + sequence_attention_list = [] + + prot_attention_list = [] + + prot_seq_attention_list = [] + + seq_prot_attention_list = [] + + sequence_enc = self.dropout(self.linear1(peptide_sequence)) + + prot_enc = self.dropout_2(self.linear2(protein_sequence)) + + for reciprocal_layer in self.reciprocal_layer_stack: + prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \ + reciprocal_layer(sequence_enc, prot_enc) + + sequence_attention_list.append(sequence_attention) + + prot_attention_list.append(prot_attention) + + prot_seq_attention_list.append(prot_seq_attention) + + seq_prot_attention_list.append(seq_prot_attention) + + return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \ + seq_prot_attention_list, seq_prot_attention_list + + +class RepeatedModule2(nn.Module): + def __init__(self, n_layers, d_model, + n_head, d_k, d_v, d_inner, dropout=0.1): + super().__init__() + + self.linear1 = nn.Linear(1280, d_model) + self.linear2 = nn.Linear(1280, d_model) + self.sequence_embedding = nn.Embedding(20, 
d_model) + self.d_model = d_model + + self.reciprocal_layer_stack = nn.ModuleList([ + ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v) + for _ in range(n_layers)]) + + self.dropout = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, peptide_sequence, protein_sequence): + sequence_attention_list = [] + + prot_attention_list = [] + + prot_seq_attention_list = [] + + seq_prot_attention_list = [] + + sequence_enc = self.dropout(self.linear1(peptide_sequence)) + + prot_enc = self.dropout_2(self.linear2(protein_sequence)) + + for reciprocal_layer in self.reciprocal_layer_stack: + prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \ + reciprocal_layer(sequence_enc, prot_enc) + + sequence_attention_list.append(sequence_attention) + + prot_attention_list.append(prot_attention) + + prot_seq_attention_list.append(prot_seq_attention) + + seq_prot_attention_list.append(seq_prot_attention) + + return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \ + seq_prot_attention_list, seq_prot_attention_list + + +class RepeatedModule(nn.Module): + + def __init__(self, n_layers, d_model, + n_head, d_k, d_v, d_inner, dropout=0.1): + + super().__init__() + + self.linear = nn.Linear(1024, d_model) + self.sequence_embedding = nn.Embedding(20, d_model) + self.d_model = d_model + + self.reciprocal_layer_stack = nn.ModuleList([ + ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v) + for _ in range(n_layers)]) + + self.dropout = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + + + + def _positional_embedding(self, batches, number): + + result = torch.exp(torch.arange(0, self.d_model,2,dtype=torch.float32)*-1*(np.log(10000)/self.d_model)) + + numbers = torch.arange(0, number, dtype=torch.float32) + + numbers = numbers.unsqueeze(0) + + numbers = numbers.unsqueeze(2) + + result = numbers*result + + result = torch.cat((torch.sin(result), torch.cos(result)),2) + + return result + + def 
forward(self, peptide_sequence, protein_sequence): + + + sequence_attention_list = [] + + prot_attention_list = [] + + prot_seq_attention_list = [] + + seq_prot_attention_list = [] + + sequence_enc = self.sequence_embedding(peptide_sequence) + + sequence_enc += to_var(self._positional_embedding(peptide_sequence.shape[0], + peptide_sequence.shape[1])) + sequence_enc = self.dropout(sequence_enc) + + + + + + prot_enc = self.dropout_2(self.linear(protein_sequence)) + + + + + for reciprocal_layer in self.reciprocal_layer_stack: + + prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention =\ + reciprocal_layer(sequence_enc, prot_enc) + + sequence_attention_list.append(sequence_attention) + + prot_attention_list.append(prot_attention) + + prot_seq_attention_list.append(prot_seq_attention) + + seq_prot_attention_list.append(seq_prot_attention) + + + + return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,\ + seq_prot_attention_list, seq_prot_attention_list + + +class FullModel(nn.Module): + + def __init__(self, n_layers, d_model, n_head, + d_k, d_v, d_inner, return_attention=False, dropout=0.2): + super().__init__() + + self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D") + + # freeze all the esm_model parameters + for param in self.esm_model.parameters(): + param.requires_grad = False + + self.repeated_module = RepeatedModule2(n_layers, d_model, + n_head, d_k, d_v, d_inner, dropout=dropout) + + self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v, dropout=dropout) + + self.final_ffn = FFN(d_model, d_inner, dropout=dropout) + + self.output_projection_prot = nn.Linear(d_model, 1) + self.sigmoid = nn.Sigmoid() + + self.return_attention = return_attention + + def forward(self, binder_tokens, target_tokens): + + with torch.no_grad(): + peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state + protein_sequence = 
self.esm_model(**target_tokens).last_hidden_state + + # pdb.set_trace() + + prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \ + seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence, + protein_sequence) + + prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc) + + # pdb.set_trace() + + prot_enc = self.final_ffn(prot_enc) + + prot_enc = self.sigmoid(self.output_projection_prot(prot_enc)) + + return prot_enc + + + +class Original_FullModel(nn.Module): + + def __init__(self, n_layers, d_model, n_head, + d_k, d_v, d_inner, return_attention=False, dropout=0.2): + + super().__init__() + self.repeated_module = RepeatedModule(n_layers, d_model, + n_head, d_k, d_v, d_inner, dropout=dropout) + + self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v, dropout=dropout) + + self.final_ffn = FFN(d_model, d_inner, dropout=dropout) + self.output_projection_prot = nn.Linear(d_model, 2) + + + + self.softmax_prot =nn.LogSoftmax(dim=-1) + + + self.return_attention = return_attention + + def forward(self, peptide_sequence, protein_sequence): + + prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,\ + seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence, + protein_sequence) + + + + prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc) + + prot_enc = self.final_ffn(prot_enc) + + prot_enc = self.softmax_prot(self.output_projection_prot(prot_enc)) + + + + + + if not self.return_attention: + return prot_enc + else: + return prot_enc, sequence_attention_list, prot_attention_list,\ + seq_prot_attention_list, seq_prot_attention_list + diff --git a/modules/bindevaluator_modules/modules.py b/modules/bindevaluator_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..63a83d8f24e4aae9b8f52957920a5cd1b2755f1e --- /dev/null +++ 
b/modules/bindevaluator_modules/modules.py @@ -0,0 +1,187 @@ +from torch import nn +import numpy as np +import torch +import torch.nn.functional as F + + +def to_var(x): + if torch.cuda.is_available(): + x = x.cuda() + return x + + + + + +class MultiHeadAttentionSequence(nn.Module): + + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + + super().__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + + self.W_Q = nn.Linear(d_model, n_head*d_k) + self.W_K = nn.Linear(d_model, n_head*d_k) + self.W_V = nn.Linear(d_model, n_head*d_v) + self.W_O = nn.Linear(n_head*d_v, d_model) + + self.layer_norm = nn.LayerNorm(d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v): + + batch, len_q, _ = q.size() + batch, len_k, _ = k.size() + batch, len_v, _ = v.size() + + Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]) + K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]) + V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]) + + Q = Q.transpose(1, 2) + K = K.transpose(1, 2).transpose(2, 3) + V = V.transpose(1, 2) + + attention = torch.matmul(Q, K) + + attention = attention / np.sqrt(self.d_k) + + attention = F.softmax(attention, dim=-1) + + output = torch.matmul(attention, V) + + output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head]) + + output = self.W_O(output) + + output = self.dropout(output) + + output = self.layer_norm(output + q) + + return output, attention + +class MultiHeadAttentionReciprocal(nn.Module): + + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + + super().__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + + + self.W_Q = nn.Linear(d_model, n_head*d_k) + self.W_K = nn.Linear(d_model, n_head*d_k) + self.W_V = nn.Linear(d_model, n_head*d_v) + self.W_O = nn.Linear(n_head*d_v, d_model) + self.W_V_2 = nn.Linear(d_model, n_head*d_v) + self.W_O_2 = nn.Linear(n_head*d_v, d_model) + + 
self.layer_norm = nn.LayerNorm(d_model) + + self.dropout = nn.Dropout(dropout) + + self.layer_norm_2 = nn.LayerNorm(d_model) + + self.dropout_2 = nn.Dropout(dropout) + + + + + def forward(self, q, k, v, v_2): + + batch, len_q, _ = q.size() + batch, len_k, _ = k.size() + batch, len_v, _ = v.size() + batch, len_v_2, _ = v_2.size() + + + Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]) + K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]) + V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]) + V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v]) + + + + Q = Q.transpose(1, 2) + K = K.transpose(1, 2).transpose(2, 3) + V = V.transpose(1, 2) + V_2 = V_2.transpose(1,2) + + attention = torch.matmul(Q, K) + + + attention = attention /np.sqrt(self.d_k) + + attention_2 = attention.transpose(-2, -1) + + + + attention = F.softmax(attention, dim=-1) + + attention_2 = F.softmax(attention_2, dim=-1) + + + output = torch.matmul(attention, V) + + output_2 = torch.matmul(attention_2, V_2) + + output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head]) + + output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head]) + + output = self.W_O(output) + + output_2 = self.W_O_2(output_2) + + output = self.dropout(output) + + output = self.layer_norm(output + q) + + output_2 = self.dropout(output_2) + + output_2 = self.layer_norm(output_2 + k) + + + + + + return output, output_2, attention, attention_2 + + +class FFN(nn.Module): + + def __init__(self, d_in, d_hid, dropout=0.1): + super().__init__() + + self.layer_1 = nn.Conv1d(d_in, d_hid,1) + self.layer_2 = nn.Conv1d(d_hid, d_in,1) + self.relu = nn.ReLU() + self.layer_norm = nn.LayerNorm(d_in) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + + residual = x + output = self.layer_1(x.transpose(1, 2)) + + output = self.relu(output) + + output = self.layer_2(output) + + output = self.dropout(output) + + output = self.layer_norm(output.transpose(1, 
2)+residual) + + return output + diff --git a/modules/bindevaluator_modules/score_domain.py b/modules/bindevaluator_modules/score_domain.py new file mode 100644 index 0000000000000000000000000000000000000000..17c56ec8b52dbed7cef4cebb1cbb5b92b091627e --- /dev/null +++ b/modules/bindevaluator_modules/score_domain.py @@ -0,0 +1,40 @@ +from scipy.stats import norm +import numpy as np +import os + + +def score(outputs): + + weight = 0.03 + binding_size_dist = np.load(os.path.join(os.path.dirname(__file__), "../params/binding_size_train_dist.npy")) + + + mean = np.mean(binding_size_dist) + + std = np.std(binding_size_dist) + + dist = norm(mean, std) + + + max_score = 0 + + + + scores = np.exp(outputs[0])[:, 1] + + indices = np.argsort(-1*scores) + + for j in range(1, len(indices)): + + + + score = (1-weight)*np.mean(scores[indices[:j]]) + weight*(dist.pdf(j/len(indices))) + + + if score > max_score: + + max_score = score + + + return max_score + \ No newline at end of file diff --git a/modules/dna_module.py b/modules/dna_module.py new file mode 100644 index 0000000000000000000000000000000000000000..01533f9eab6350ef5f3f0d31a286e31ae0cc9bba --- /dev/null +++ b/modules/dna_module.py @@ -0,0 +1,301 @@ +import copy +import math +from collections import defaultdict + +import PIL +import numpy as np +import pandas as pd +import torch, time, os +import wandb +import seaborn as sns +import yaml + +sns.set_style('whitegrid') +from matplotlib import pyplot as plt +from torch import optim + +from models.dna_models import MLPModel, CNNModel, TransformerModel, DeepFlyBrainModel +from utils.flow_utils import DirichletConditionalFlow, expand_simplex, sample_cond_prob_path, simplex_proj, \ + get_wasserstein_dist, update_ema, load_flybrain_designed_seqs +from modules.general_module import GeneralModule +from utils.log import get_logger + +from flow_matching.path import MixtureDiscreteProbPath +from flow_matching.path.scheduler import PolynomialConvexScheduler +from flow_matching.solver 
import MixtureDiscreteEulerSolver +from flow_matching.utils import ModelWrapper +from flow_matching.loss import MixturePathGeneralizedKL + +import pdb + + +logger = get_logger(__name__) + + +class DNAModule(GeneralModule): + def __init__(self, args, alphabet_size, num_cls, source_distribution="uniform"): + super().__init__(args) + self.alphabet_size = alphabet_size + self.source_distribution = source_distribution + self.epsilon = 1e-3 + + if source_distribution == "uniform": + added_token = 0 + elif source_distribution == "mask": + self.mask_token = alphabet_size # tokens starting from zero + added_token = 1 + else: + raise NotImplementedError + self.alphabet_size += added_token + + self.load_model(self.alphabet_size, num_cls) + + self.scheduler = PolynomialConvexScheduler(n=args.scheduler_n) + self.path = MixtureDiscreteProbPath(scheduler=self.scheduler) + self.loss_fn = MixturePathGeneralizedKL(path=self.path) + + self.val_outputs = defaultdict(list) + self.train_outputs = defaultdict(list) + self.train_out_initialized = False + self.mean_log_ema = {} + if self.args.taskiran_seq_path is not None: + self.taskiran_fly_seqs = load_flybrain_designed_seqs(self.args.taskiran_seq_path).to(self.device) + + def on_load_checkpoint(self, checkpoint): + checkpoint['state_dict'] = {k: v for k,v in checkpoint['state_dict'].items() if 'cls_model' not in k and 'distill_model' not in k} + + def training_step(self, batch, batch_idx): + self.stage = 'train' + loss = self.general_step(batch, batch_idx) + if self.args.ckpt_iterations is not None and self.trainer.global_step in self.args.ckpt_iterations: + self.trainer.save_checkpoint(os.path.join(os.environ["MODEL_DIR"],f"epoch={self.trainer.current_epoch}-step={self.trainer.global_step}.ckpt")) + # self.try_print_log() + return loss + + def validation_step(self, batch, batch_idx): + self.stage = 'val' + loss = self.general_step(batch, batch_idx) + # if self.args.validate: + # self.try_print_log() + + def general_step(self, batch, 
batch_idx=None): + self.iter_step += 1 + x_1, cls = batch + B, L = x_1.shape + x_1 = x_1.to(self.device) + + if self.source_distribution == "uniform": + x_0 = torch.randint_like(x_1, high=self.alphabet_size) + elif self.source_distribution == "mask": + x_0 = torch.zeros_like(x_1) + self.mask_token + else: + raise NotImplementedError + # pdb.set_trace() + t = torch.rand(x_1.shape[0]) * (1 - self.epsilon) + t = t.to(x_1.device) + path_sample = self.path.sample(t=t, x_0=x_0, x_1=x_1) + + logits = self.model(x_t=path_sample.x_t, t=path_sample.t) + loss = self.loss_fn(logits=logits, x_1=x_1, x_t=path_sample.x_t, t=path_sample.t) + # pdb.set_trace() + + self.lg('loss', loss) + if self.stage == "val": + predicted = logits.argmax(dim=-1) + accuracy = (predicted == x_1).float().mean() + self.lg('acc', accuracy) + self.last_log_time = time.time() + return loss + + @torch.no_grad() + def dirichlet_flow_inference(self, seq, cls, model, args): + B, L = seq.shape + K = model.alphabet_size + x0 = torch.distributions.Dirichlet(torch.ones(B, L, model.alphabet_size, device=seq.device)).sample() + eye = torch.eye(K).to(x0) + xt = x0.clone() + + t_span = torch.linspace(1, args.alpha_max, self.args.num_integration_steps, device=self.device) + for i, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])): + xt_expanded, prior_weights = expand_simplex(xt, s[None].expand(B), args.prior_pseudocount) + + logits = model(xt_expanded, t=s[None].expand(B)) + flow_probs = torch.nn.functional.softmax(logits / args.flow_temp, -1) # [B, L, K] + + if not torch.allclose(flow_probs.sum(2), torch.ones((B, L), device=self.device), atol=1e-4) or not (flow_probs >= 0).all(): + print(f'WARNING: flow_probs.min(): {flow_probs.min()}. Some values of flow_probs do not lie on the simplex. There are we are {(flow_probs<0).sum()} negative values in flow_probs of shape {flow_probs.shape} that are negative. 
We are projecting them onto the simplex.') + flow_probs = simplex_proj(flow_probs) + + c_factor = self.condflow.c_factor(xt.cpu().numpy(), s.item()) + c_factor = torch.from_numpy(c_factor).to(xt) + + self.inf_counter += 1 + + if not (flow_probs >= 0).all(): print(f'flow_probs.min(): {flow_probs.min()}') + cond_flows = (eye - xt.unsqueeze(-1)) * c_factor.unsqueeze(-2) + flow = (flow_probs.unsqueeze(-2) * cond_flows).sum(-1) + + xt = xt + flow * (t - s) + + if not torch.allclose(xt.sum(2), torch.ones((B, L), device=self.device), atol=1e-4) or not (xt >= 0).all(): + print(f'WARNING: xt.min(): {xt.min()}. Some values of xt do not lie on the simplex. There are we are {(xt<0).sum()} negative values in xt of shape {xt.shape} that are negative. We are projecting them onto the simplex.') + xt = simplex_proj(xt) + return logits, x0 + + def on_validation_epoch_start(self): + self.inf_counter = 1 + self.nan_inf_counter = 0 + + def on_validation_epoch_end(self): + self.generator = np.random.default_rng() + log = self._log + log = {key: log[key] for key in log if "val_" in key} + log = self.gather_log(log, self.trainer.world_size) + mean_log = self.get_log_mean(log) + mean_log.update({'val_nan_inf_step_fraction': self.nan_inf_counter / self.inf_counter}) + + mean_log.update({'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)}) + + self.mean_log_ema = update_ema(current_dict=mean_log, prev_ema=self.mean_log_ema, gamma=0.9) + mean_log.update(self.mean_log_ema) + if self.trainer.is_global_zero: + logger.info(str(mean_log)) + self.log_dict(mean_log, batch_size=1) + if self.args.wandb: + wandb.log(mean_log) + + path = os.path.join(os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv") + pd.DataFrame(log).to_csv(path) + + for key in list(log.keys()): + if "val_" in key: + del self._log[key] + self.val_outputs = defaultdict(list) + + + def on_train_epoch_start(self) -> None: + self.inf_counter = 1 + 
self.nan_inf_counter = 0 + # if not self.loaded_distill_model and self.args.distill_ckpt is not None: + # self.load_distill_model() + # self.loaded_distill_model = True + # if not self.loaded_classifiers: + # self.load_classifiers(load_cls=self.args.cls_ckpt is not None, load_clean_cls=self.args.clean_cls_ckpt is not None) + # self.loaded_classifiers = True + + def on_train_epoch_end(self): + self.train_out_initialized = True + log = self._log + log = {key: log[key] for key in log if "train_" in key} + log = self.gather_log(log, self.trainer.world_size) + mean_log = self.get_log_mean(log) + mean_log.update( + {'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)}) + + if self.trainer.is_global_zero: + logger.info(str(mean_log)) + self.log_dict(mean_log, batch_size=1) + if self.args.wandb: + wandb.log(mean_log) + + for key in list(log.keys()): + if "train_" in key: + del self._log[key] + + def lg(self, key, data): + if isinstance(data, torch.Tensor): + data = data.detach().cpu().numpy() + log = self._log + if self.args.validate or self.stage == 'train': + log["iter_" + key].append(data) + log[self.stage + "_" + key].append(data) + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters(), lr=self.args.lr) + return optimizer + + def plot_empirical_and_true(self, empirical_dist, true_dist): + num_datasets_to_plot = min(4, empirical_dist.shape[0]) + width = 1 + # Creating a figure and axes + fig, axes = plt.subplots(math.ceil(num_datasets_to_plot/2), 2, figsize=(10, 8)) + for i in range(num_datasets_to_plot): + row, col = i // 2, i % 2 + x = np.arange(len(empirical_dist[i])) + axes[row, col].bar(x, empirical_dist[i], width, label=f'empirical') + axes[row, col].plot(x, true_dist[i], label=f'true density', color='orange') + axes[row, col].legend() + axes[row, col].set_title(f'Sequence position {i + 1}') + axes[row, col].set_xlabel('Category') + axes[row, col].set_ylabel('Density') + 
plt.tight_layout() + fig.canvas.draw() + pil_img = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) + plt.close() + return pil_img + + def load_model(self, alphabet_size, num_cls): + if self.args.model == 'cnn': + self.model = CNNModel(self.args, alphabet_size=alphabet_size) + elif self.args.model == 'mlp': + self.model = MLPModel(input_dim=alphabet_size, time_dim=1, hidden_dim=self.args.hidden_dim, length=self.args.length) + elif self.args.model == 'transformer': + self.model = TransformerModel(alphabet_size=alphabet_size, seq_length=self.args.length, embed_dim=self.args.hidden_dim, \ + num_layers=self.args.num_layers, num_heads=self.args.num_heads, dropout=self.args.dropout) + elif self.args.model == 'deepflybrain': + self.model = DeepFlyBrainModel(self.args, alphabet_size=alphabet_size,num_cls=num_cls) + else: + raise NotImplementedError() + + def plot_score_and_probs(self): + clss = torch.cat(self.val_outputs['clss_noisycls']) + probs = torch.softmax(torch.cat(self.val_outputs['logits_noisycls']), dim=-1) + scores = torch.cat(self.val_outputs['scores_noisycls']).cpu().numpy() + score_norms = np.linalg.norm(scores, axis=-1) + alphas = torch.cat(self.val_outputs['alphas_noisycls']).cpu().numpy() + true_probs = probs[torch.arange(len(probs)), clss].cpu().numpy() + bins = np.linspace(min(alphas), 12, 20) + indices = np.digitize(alphas, bins) + bin_means = [np.mean(true_probs[indices == i]) for i in range(1, len(bins))] + bin_std = [np.std(true_probs[indices == i]) for i in range(1, len(bins))] + bin_centers = 0.5 * (bins[:-1] + bins[1:]) + + bin_pos_std = [np.std(true_probs[indices == i][true_probs[indices == i] > np.mean(true_probs[indices == i])]) for i in range(1, len(bins))] + bin_neg_std = [np.std(true_probs[indices == i][true_probs[indices == i] < np.mean(true_probs[indices == i])]) for i in range(1, len(bins))] + plot_data = pd.DataFrame({'Alphas': bin_centers, 'Means': bin_means, 'Std': bin_std, 'Pos_Std': bin_pos_std, 
'Neg_Std': bin_neg_std}) + plt.figure(figsize=(10, 6)) + sns.lineplot(x='Alphas', y='Means', data=plot_data) + plt.fill_between(plot_data['Alphas'], plot_data['Means'] - plot_data['Neg_Std'], plot_data['Means'] + plot_data['Pos_Std'], alpha=0.3) + plt.xlabel('Binned alphas values') + plt.ylabel('Mean of predicted probs for true class') + fig = plt.gcf() + fig.canvas.draw() + pil_probs = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) + + plt.close() + bin_means = [np.mean(score_norms[indices == i]) for i in range(1, len(bins))] + bin_std = [np.std(score_norms[indices == i]) for i in range(1, len(bins))] + bin_pos_std = [np.std(score_norms[indices == i][score_norms[indices == i] > np.mean(score_norms[indices == i])]) for i in range(1, len(bins))] + bin_neg_std = [np.std(score_norms[indices == i][score_norms[indices == i] < np.mean(score_norms[indices == i])]) for i in range(1, len(bins))] + plot_data = pd.DataFrame({'Alphas': bin_centers, 'Means': bin_means, 'Std': bin_std, 'Pos_Std': bin_pos_std, 'Neg_Std': bin_neg_std}) + plt.figure(figsize=(10, 6)) + sns.lineplot(x='Alphas', y='Means', data=plot_data) + plt.fill_between(plot_data['Alphas'], plot_data['Means'] - plot_data['Neg_Std'], + plot_data['Means'] + plot_data['Pos_Std'], alpha=0.3) + plt.xlabel('Binned alphas values') + plt.ylabel('Mean of norm of the scores') + fig = plt.gcf() + fig.canvas.draw() + pil_score_norms = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) + return pil_probs, pil_score_norms + + def log_data_similarities(self, seq_pred): + similarities1 = seq_pred.cpu()[:, None, :].eq(self.toy_data.data_class1[None, :, :]) # batchsize, dataset_size, seq_len + similarities2 = seq_pred.cpu()[:, None, :].eq(self.toy_data.data_class2[None, :, :]) # batchsize, dataset_size, seq_len + similarities = seq_pred.cpu()[:, None, :].eq(torch.cat([self.toy_data.data_class2[None, :, :], self.toy_data.data_class1[None, :, :]],dim=1)) # 
batchsize, dataset_size, seq_len + self.lg('data1_sim', similarities1.float().mean(-1).max(-1)[0]) + self.lg('data2_sim', similarities2.float().mean(-1).max(-1)[0]) + self.lg('data_sim', similarities.float().mean(-1).max(-1)[0]) + self.lg('mean_data1_sim', similarities1.float().mean(-1).mean(-1)) + self.lg('mean_data2_sim', similarities2.float().mean(-1).mean(-1)) + self.lg('mean_data_sim', similarities.float().mean(-1).mean(-1)) diff --git a/modules/general_module.py b/modules/general_module.py new file mode 100644 index 0000000000000000000000000000000000000000..ef6c3036149bd9c1f3165b6e87a0b63a73477958 --- /dev/null +++ b/modules/general_module.py @@ -0,0 +1,118 @@ +import os + +import pandas as pd +import torch, time, wandb +from collections import defaultdict +import pytorch_lightning as pl +import numpy as np +import pdb +from utils.log import get_logger + +logger = get_logger(__name__) + + + + + +class GeneralModule(pl.LightningModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + self.args = args + + self.iter_step = -1 + self._log = defaultdict(list) + self.generator = np.random.default_rng() + self.last_log_time = time.time() + + + def try_print_log(self): + + step = self.iter_step if self.args.validate else self.trainer.global_step + if (step + 1) % self.args.print_freq == 0: + print(os.environ["MODEL_DIR"]) + log = self._log + log = {key: log[key] for key in log if "iter_" in key} + + log = self.gather_log(log, self.trainer.world_size) + mean_log = self.get_log_mean(log) + mean_log.update( + {'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)}) + if self.trainer.is_global_zero: + print(str(mean_log)) + self.log_dict(mean_log, batch_size=1) + if self.args.wandb: + wandb.log(mean_log) + for key in list(log.keys()): + if "iter_" in key: + del self._log[key] + + def lg(self, key, data): + if isinstance(data, torch.Tensor): + data = 
data.detach().cpu().item() + log = self._log + # pdb.set_trace() + if self.args.validate or self.stage == 'train': + log["iter_" + key].append(data) + log[self.stage + "_" + key].append(data) + + def on_train_epoch_end(self): + log = self._log + log = {key: log[key] for key in log if "train_" in key} + log = self.gather_log(log, self.trainer.world_size) + mean_log = self.get_log_mean(log) + mean_log.update( + {'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)}) + + if self.trainer.is_global_zero: + logger.info(str(mean_log)) + self.log_dict(mean_log, batch_size=1) + if self.args.wandb: + wandb.log(mean_log) + + for key in list(log.keys()): + if "train_" in key: + del self._log[key] + + def on_validation_epoch_end(self): + self.generator = np.random.default_rng() + log = self._log + log = {key: log[key] for key in log if "val_" in key} + log = self.gather_log(log, self.trainer.world_size) + mean_log = self.get_log_mean(log) + mean_log.update( + {'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)}) + + if self.trainer.is_global_zero: + logger.info(str(mean_log)) + self.log_dict(mean_log, batch_size=1) + if self.args.wandb: + wandb.log(mean_log) + + path = os.path.join( + os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv" + ) + pd.DataFrame(log).to_csv(path) + + for key in list(log.keys()): + if "val_" in key: + del self._log[key] + + + + def gather_log(self, log, world_size): + if world_size == 1: + return log + log_list = [None] * world_size + torch.distributed.all_gather_object(log_list, log) + log = {key: sum([l[key] for l in log_list], []) for key in log} + return log + + def get_log_mean(self, log): + out = {} + for key in log: + try: + out[key] = np.nanmean(log[key]) + except: + pass + return out \ No newline at end of file diff --git a/moo.py b/moo.py new file mode 100644 index 
0000000000000000000000000000000000000000..cf5c24ea41b6d7f64f0514477ad2d936ad8be00d --- /dev/null +++ b/moo.py @@ -0,0 +1,91 @@ +import yaml +from tqdm import tqdm +import torch +from torch import nn +from transformers import AutoTokenizer + +from models.peptide_classifiers import * + +from utils.parsing import parse_guidance_args +args = parse_guidance_args() + +import pdb +import random +import inspect +import csv + +# MOO hyper-parameters +step_size = 1 / 100 +n_samples = 1 +vocab_size = 24 +source_distribution = "uniform" +device = 'cuda:0' + +length = args.length +target = args.target_protein +if args.motifs: + motifs = parse_motifs(args.motifs).to(device) + print(motifs) + +tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") +target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device) + +# Load Models +solver = load_solver('/scratch/pranamlab/tong/checkpoints/MOG-DFM/ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device) + +bindevaluator = load_bindevaluator('/scratch/pranamlab/tong/checkpoints/BindEvaluator/model_path/finetuned_BindEvaluator.ckpt', device) +motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=args.motif_penalty) + +affinity_predictor = load_affinity_predictor('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/binding_affinity_unpooled.pt', device) +affinity_model = AffinityModel(affinity_predictor, target_sequence) +hemolysis_model = HemolysisModel(device=device) +nonfouling_model = NonfoulingModel(device=device) +solubility_model = SolubilityModelNew(device=device) +halflife_model = HalfLifeModel(device=device) + +score_models = [hemolysis_model, nonfouling_model, halflife_model, affinity_model, motif_model] + +for i in range(args.n_batches): + if source_distribution == "uniform": + x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device) # CHANGE! 
+ elif source_distribution == "mask": + x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long() + else: + raise NotImplementedError + + zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device) + twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device) + x_init = torch.cat([zeros, x_init, twos], dim=1) + + x_1 = solver.multi_guidance_sample(args=args, x_init=x_init, + step_size=step_size, + verbose=True, + time_grid=torch.tensor([0.0, 1.0-1e-3]), + score_models=score_models, + num_objectives=len(score_models) + int(args.motif_penalty), + weights=args.weights) + + samples = x_1.tolist() + samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples] + print(samples) + + scores = [] + for i, s in enumerate(score_models): + sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s) + if 't' in sig.parameters: + candidate_scores = s(x_1, 1) + else: + candidate_scores = s(x_1) + + if isinstance(candidate_scores, tuple): + for score in candidate_scores: + scores.append(score.item()) + else: + scores.append(candidate_scores.item()) + print(scores) + + with open(args.output_file, 'a') as f: + f.write(samples[0]) + for score in scores: + f.write(f",{score}") + f.write('\n') \ No newline at end of file diff --git a/moppit.py b/moppit.py new file mode 100644 index 0000000000000000000000000000000000000000..54c0b2f427d2f5846f30aef824fa2100e249d203 --- /dev/null +++ b/moppit.py @@ -0,0 +1,90 @@ +import yaml +from tqdm import tqdm +import torch +from torch import nn +from transformers import AutoTokenizer + +from models.peptide_classifiers import * + +from utils.parsing import parse_guidance_args +args = parse_guidance_args() + +import pdb +import random +import inspect + +# MOO hyper-parameters +step_size = 1 / 100 +n_samples = 1 +length = args.length +target = args.target_protein +motifs = args.motifs # args.motifs +vocab_size = 24 +source_distribution = "uniform" +device = 
'cuda:0' + +tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") +target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device) +motifs = parse_motifs(motifs).to(device) +print(motifs) + +# Load Models +solver = load_solver('/scratch/pranamlab/tong/checkpoints/MOG-DFM/ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device) + +bindevaluator = load_bindevaluator('/scratch/pranamlab/tong/checkpoints/BindEvaluator/model_path/finetuned_BindEvaluator.ckpt', device) +motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=True) + +affinity_predictor = load_affinity_predictor('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/binding_affinity_unpooled.pt', device) +affinity_model = AffinityModel(affinity_predictor, target_sequence) + +score_models = [motif_model, affinity_model] + +for i in range(args.n_batches): + if source_distribution == "uniform": + x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device) + elif source_distribution == "mask": + x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long() + else: + raise NotImplementedError + + zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device) + twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device) + x_init = torch.cat([zeros, x_init, twos], dim=1) + + x_1 = solver.multi_guidance_sample(args=args, x_init=x_init, + step_size=step_size, + verbose=True, + time_grid=torch.tensor([0.0, 1.0-1e-3]), + score_models=score_models, + num_objectives=3, + weights=args.weights) + + samples = x_1.tolist() + samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples] + print(samples) + + scores = [] + for i, s in enumerate(score_models): + sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s) + if 't' in sig.parameters: + candidate_scores = s(x_1, 1) + else: + candidate_scores = s(x_1) + + 
if isinstance(candidate_scores, tuple): + for score in candidate_scores: + scores.append(score.item()) + else: + scores.append(candidate_scores.item()) + print(scores) + + with open(args.output_file, 'a') as f: + f.write(samples[0]) + for score in scores: + f.write(f",{score}") + f.write('\n') + # samples = x_1.tolist() + # sample = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples][0] + # with open(f"/vast/home/c/chentong/MOG-DFM/samples/{name}.csv", "a") as f: + # f.write(sample + ',' + str(score_list_0[-1]) + ',' + str(score_list_1[-1]) + '\n') + diff --git a/scripts/bindevaluator.sh b/scripts/bindevaluator.sh new file mode 100644 index 0000000000000000000000000000000000000000..3ab207ec951b47a96c436e9d71f24fb733c799ea --- /dev/null +++ b/scripts/bindevaluator.sh @@ -0,0 +1,12 @@ +export CUDA_VISIBLE_DEVICES=0 + +python -u bindevaluator.py \ +-target MHVPSGAQLGLRPDLLARRRLKRCPSRWLCLSAAWSFVQVFSEPDGFTVIFSGLGNNAGGTMHWNDTRPAHFRILKVVLREAVAECLMDSYSLDVHGGRRTAAG \ +-binder YVEICRCVVC \ +-sm /scratch/pranamlab/tong/checkpoints/BindEvaluator/model_path/finetuned_BindEvaluator.ckpt \ +-n_layers 8 \ +-d_model 128 \ +-d_hidden 128 \ +-n_head 8 \ +-d_inner 64 +# -motifs "16, 18, 19, 21, 22, 25, 27, 32, 34, 91, 114, 115, 129, 133, 206, 224, 225, 226, 228, 234, 235" \ No newline at end of file diff --git a/scripts/moo.sh b/scripts/moo.sh new file mode 100644 index 0000000000000000000000000000000000000000..72d72aa56665871429782439a6644e58027cb90a --- /dev/null +++ b/scripts/moo.sh @@ -0,0 +1,25 @@ +export CUDA_VISIBLE_DEVICES=7 + +python -u moo.py \ +--output_file '/scratch/pranamlab/tong/MOG-DFM/collaboration/FFX.csv' \ +--length 10 \ +--n_batches 600 \ +--weights 1 1 1 4 4 2 \ +--motifs '16-31,62-79' \ +--motif_penalty \ +--target_protein MHVPSGAQLGLRPDLLARRRLKRCPSRWLCLSAAWSFVQVFSEPDGFTVIFSGLGNNAGGTMHWNDTRPAHFRILKVVLREAVAECLMDSYSLDVHGGRRTAAG + +# python -u moo.py \ +# --output_file '/scratch/pranamlab/tong/MOG-DFM/collaboration/new_GFAP_365-374.csv' \ +# --length 
10 \ +# --weights 1 1 1 1 1 8 8 \ +# --motifs '365-374' \ +# --target_protein MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM + +# python -u moo.py \ +# --output_file '/scratch/pranamlab/tong/MOG-DFM/collaboration/new_GFAP_76-79_365-374.csv' \ +# --length 10 \ +# --weights 1 1 1 1 1 8 8 \ +# --motifs '76-79, 365-374' \ +# --target_protein MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM + diff --git a/scripts/moppit.sh b/scripts/moppit.sh new file mode 100644 index 0000000000000000000000000000000000000000..e0a02408cd0a982c5ebe56448d044e45dd5cc05a --- /dev/null +++ b/scripts/moppit.sh @@ -0,0 +1,9 @@ +export CUDA_VISIBLE_DEVICES=7 + +python -u moppit.py \ +--output_file '/scratch/pranamlab/tong/MOG-DFM/collaboration/FFX.csv' \ +--n_batches 100 \ +--length 10 \ +--weights 1 1 2 \ +--motifs '16-31,62-79' \ +--target_protein MHVPSGAQLGLRPDLLARRRLKRCPSRWLCLSAAWSFVQVFSEPDGFTVIFSGLGNNAGGTMHWNDTRPAHFRILKVVLREAVAECLMDSYSLDVHGGRRTAAG \ No newline at end of file