Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,504 Bytes
7968cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import torch
import torch.nn.functional as F
import torchmetrics
from src.tools.utils import load_yaml_config
from tqdm import tqdm
class SequenceSampler():
    """Samples candidate sequences from model logits and scores sequence recovery.

    Two sampling strategies are supported:
      * 'primitive' — independent temperature-scaled categorical sampling from the
        logits (ignores the context of already-decoded residues);
      * 'pMPNN'     — autoregressive sampling with a provided ProteinMPNN model.

    Recovery can be evaluated directly against the ground-truth sequence (hard /
    oracle recovery) or indirectly by selecting the sequence whose predicted
    B-factor flexibility profile best matches the ground truth.
    """

    def __init__(self, num_sequences, sampling_temperature, sampling_type='primitive',
                 bfactor_recovery_metric='pearson_R', pMPNN_model=None) -> None:
        """
        Args:
            num_sequences: number of candidate sequences drawn per evaluation.
            sampling_temperature: softmax temperature (> 0); higher values give
                more diverse samples.
            sampling_type: 'primitive' or 'pMPNN'.
            bfactor_recovery_metric: currently only 'pearson_R' is implemented.
            pMPNN_model: ProteinMPNN model instance; required when
                sampling_type == 'pMPNN'.

        Raises:
            ValueError: if 'pMPNN' sampling is requested without a model.
            NotImplementedError: for unknown sampling types or metrics.
        """
        self.num_sequences = num_sequences
        self.temperature = sampling_temperature
        self.primitively_sampled = None
        self.hard_sampled = None
        self.bfactor_predictor = None  # lazily created by load_bfactor_predictor()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Explicit `is not None`: truthiness of an nn.Module is not a reliable
        # presence check.
        self.pMPNN_model = pMPNN_model.to(self.device) if pMPNN_model is not None else None
        if sampling_type == 'primitive':
            self.chosen_sampling = self.primitive_sampling
        elif sampling_type == 'pMPNN':
            if pMPNN_model is None:
                raise ValueError('pMPNN model must be provided for pMPNN sampling')
            self.chosen_sampling = self.proteinMPNN_sampling
        else:
            raise NotImplementedError
        if bfactor_recovery_metric == 'pearson_R':
            self.bfactor_recovery_metric = torchmetrics.PearsonCorrCoef().to(self.device)
        else:
            raise NotImplementedError

    def softmax_with_temperature(self, logits):
        """Return softmax probabilities of `logits` scaled by 1 / temperature."""
        scaled_logits = logits / self.temperature
        return torch.softmax(scaled_logits, dim=-1)

    def sample_from_logits(self, logits):
        """Draw one categorical sample per position from temperature-scaled logits."""
        probabilities = self.softmax_with_temperature(logits)
        return torch.distributions.Categorical(probabilities).sample()

    def proteinMPNN_sampling(self, logits, batch=None):
        """Autoregressively sample `num_sequences` sequences with ProteinMPNN.

        Args:
            logits: unused; present so all sampling strategies share a signature.
            batch: dict of ProteinMPNN-featurized tensors (X, S, mask, chain_M,
                chain_M_pos, residue_idx, chain_encoding_all).

        Returns:
            list of sampled sequence tensors (the 'S' entry of each sample call).

        Raises:
            ValueError: if `batch` is not provided.
        """
        if batch is None:
            raise ValueError('Batch featurized for ProteinMPNN must be provided for pMPNN sampling')
        with torch.no_grad():
            # Batch unpacking is loop-invariant; hoist it out of the sampling loop.
            X, S, mask, chain_M, chain_M_pos, residue_idx, chain_encoding_all = (
                batch['X'], batch['S'], batch['mask'], batch['chain_M'],
                batch['chain_M_pos'], batch['residue_idx'], batch['chain_encoding_all'])
            sampled = []
            for _ in range(self.num_sequences):
                # Fresh random decoding order per sample.
                randn = torch.randn(chain_M.shape, device=X.device)
                out = self.pMPNN_model.sample(
                    X=X, randn=randn, S_true=S, chain_mask=chain_M,
                    chain_M_pos=chain_M_pos, chain_encoding_all=chain_encoding_all,
                    residue_idx=residue_idx, mask=mask, temperature=self.temperature)
                sampled.append(out['S'])
            return sampled

    def primitive_sampling(self, logits, batch=None):
        """Sample `num_sequences` sequences independently per position.

        Leads to poor sequences since it ignores the context of already-decoded
        amino acids. `batch` is accepted (and ignored) so all sampling strategies
        share a signature.
        """
        return [self.sample_from_logits(logits) for _ in range(self.num_sequences)]

    def load_bfactor_predictor(self, config_path='./src/models/configs/FlexibilityProtTrans.yaml'):
        """Instantiate the B-factor (flexibility) predictor from a YAML config.

        Imported lazily so the heavy model code is only loaded when needed.
        """
        print('Loading model based on the configs:', config_path)
        from src.models.prottrans import FlexibilityProtTrans
        flex_params = load_yaml_config(config_path)
        self.bfactor_predictor = FlexibilityProtTrans(**flex_params)

    def hard_sampling(self, logits):
        """Return the greedy (argmax) sequence from the logits."""
        return torch.argmax(logits, dim=-1)

    def eval_oracle_recovery(self, gt_seq, logits, mask, batch=None):
        """Compute hard (argmax) and oracle (best-of-samples) sequence recovery.

        Args:
            gt_seq: ground-truth sequence tensor.
            logits: model logits over residue types.
            mask: positions to include in the recovery fraction.
            batch: forwarded to the sampling strategy (needed for pMPNN).

        Returns:
            (hard_recovery, oracle_recovery) — oracle is the maximum recovery
            over the argmax sequence and all sampled sequences, so it is always
            >= hard_recovery.
        """
        sampled_seqs = self.chosen_sampling(logits, batch=batch)
        hard_seq = self.hard_sampling(logits)
        hard_recovery = ((hard_seq == gt_seq) * mask).sum() / mask.sum()
        recoveries = [hard_recovery]
        for seq in sampled_seqs:
            recoveries.append(((seq == gt_seq) * mask).sum() / mask.sum())
        oracle_recovery = max(recoveries)
        return hard_recovery, oracle_recovery

    def eval_bfactor_profile_recovery(self, gt_bfactors, gt_seq, logits, mask, batch=None,
                                      num_classes=33):
        """Select the sampled sequence whose predicted B-factor profile best
        matches `gt_bfactors`, and return that sequence's recovery vs `gt_seq`.

        Args:
            gt_bfactors: ground-truth normalized B-factors.
            gt_seq: ground-truth sequence tensor.
            logits: model logits over residue types.
            mask: positions to include in the metrics.
            batch: forwarded to the sampling strategy (needed for pMPNN).
            num_classes: one-hot alphabet size fed to the flexibility predictor
                (default 33, matching the original hard-coded value).

        Returns:
            Sequence recovery fraction of the B-factor-selected sequence.
        """
        if self.bfactor_predictor is None:
            self.load_bfactor_predictor()
        sampled_seqs = self.chosen_sampling(logits, batch=batch)
        sampled_seqs.append(self.hard_sampling(logits))
        # Track the best (recovery, sequence) pair directly instead of keying a
        # dict by tensors (identity hashing is fragile). First maximum wins,
        # matching max()'s tie behavior on the original list.
        best_recovery, best_seq = None, None
        for seq in tqdm(sampled_seqs):
            one_hot_seq = F.one_hot(seq, num_classes=num_classes)
            flex_model_input = one_hot_seq.permute(0, 2, 1).float().to(self.device)
            pred_bfactors = self.bfactor_predictor(flex_model_input)['predicted_normalized_bfactors'][:, :-1, 0]
            # Exclude NaN predictions in addition to padded positions.
            _flex_mask = (mask * ~torch.isnan(pred_bfactors)).int()
            idx = torch.where(_flex_mask)
            bfactor_recovery = self.bfactor_recovery_metric(pred_bfactors[idx], gt_bfactors[idx])
            if best_recovery is None or bfactor_recovery > best_recovery:
                best_recovery, best_seq = bfactor_recovery, seq
        return ((best_seq == gt_seq) * mask).sum() / mask.sum()

    def eval_multiple_predictions_oracle(self, predictions):
        """Evaluate hard and oracle recovery for each prediction dict.

        Each prediction must provide 'log_probs', 'mask', 'original_sequence'
        and 'batch'. Returns two parallel lists of recoveries.
        """
        hard_recovery, oracle_recovery = [], []
        print('Evaluating sequence recovery by oracle...')
        for pred in tqdm(predictions):
            _hr, _or = self.eval_oracle_recovery(
                logits=pred['log_probs'], mask=pred['mask'],
                gt_seq=pred['original_sequence'], batch=pred['batch'])
            hard_recovery.append(_hr)
            oracle_recovery.append(_or)
        return hard_recovery, oracle_recovery
class SamplerGrid():
    """A grid of SequenceSampler instances spanning every combination of
    sampling temperature and sample size, used to pick the best configuration.
    """

    def __init__(self, sample_sizes: list, sampling_temperatures: list,
                 sampling_type='primitive', pMPNN_model=None) -> None:
        """Build one sampler per (temperature, sample size) pair."""
        grid = []
        for temp in sampling_temperatures:
            for size in sample_sizes:
                grid.append(SequenceSampler(size, temp, sampling_type, pMPNN_model=pMPNN_model))
        self.samplers = grid
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def get_optimal_sampler(self, predictions):
        """Evaluate every sampler on `predictions` and return the one with the
        highest average oracle recovery.
        """
        oracle_recoveries = []
        for sampler in self.samplers:
            hard, oracle = sampler.eval_multiple_predictions_oracle(predictions)
            avg_hard_recovery = sum(hard) / len(hard)
            avg_oracle_recovery = sum(oracle) / len(oracle)
            oracle_recoveries.append(avg_oracle_recovery)
            print(f'T = {sampler.temperature}, Sample_size = {sampler.num_sequences}, Average hard recovery: {avg_hard_recovery}, average oracle recovery: {avg_oracle_recovery}')
        best_idx = oracle_recoveries.index(max(oracle_recoveries))
        best_sampler = self.samplers[best_idx]
        print(f'Best sampler: T = {best_sampler.temperature}, Sample_size = {best_sampler.num_sequences}, with oracle recovery: {max(oracle_recoveries)}')
        return best_sampler

    def eval_bfactor_selection(self, predictions, sampler=None):
        """Score sequence recovery when the emitted sequence is chosen by
        B-factor profile match; returns the mean recovery over `predictions`.

        If no sampler is given, the optimal one is selected first via
        get_optimal_sampler.
        """
        if not sampler:
            sampler = self.get_optimal_sampler(predictions)
        recoveries = []
        print('Evaluating sequence recovery by bfactor profile...')
        for pred in tqdm(predictions):
            recovery = sampler.eval_bfactor_profile_recovery(
                logits=pred['log_probs'], mask=pred['mask'],
                gt_seq=pred['original_sequence'], gt_bfactors=pred['gt_bfactors'],
                batch=pred['batch'])
            recoveries.append(recovery)
        avg_recovery = sum(recoveries) / len(recoveries)
        print(f'Average sequence recovery by bfactor profile: {avg_recovery}')
        return avg_recovery