File size: 8,504 Bytes
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import torch
import torch.nn.functional as F
import torchmetrics
from src.tools.utils import load_yaml_config
from tqdm import tqdm

class SequenceSampler():
    """Sample candidate sequences from per-position logits and score them.

    Two sampling back-ends are available, selected by ``sampling_type``:

    * ``'primitive'`` - every position is drawn independently from the
      temperature-scaled softmax of the logits.
    * ``'pMPNN'`` - sequences are produced by calling ``sample`` on the
      supplied ``pMPNN_model`` with a featurized batch.

    Sampled sequences can be scored by masked sequence recovery against a
    ground-truth sequence (``eval_oracle_recovery``) or ranked by how well the
    B-factor profile predicted for them matches a ground-truth profile
    (``eval_bfactor_profile_recovery``).
    """

    def __init__(self, num_sequences, sampling_temperature, sampling_type = 'primitive', bfactor_recovery_metric = 'pearson_R', pMPNN_model = None) -> None:
        """Configure the sampler.

        Args:
            num_sequences: Number of sequences drawn per sampling call.
            sampling_temperature: Temperature applied to the logits (primitive
                sampling) or forwarded to ``pMPNN_model.sample`` (pMPNN sampling).
            sampling_type: ``'primitive'`` or ``'pMPNN'``.
            bfactor_recovery_metric: Only ``'pearson_R'`` is supported.
            pMPNN_model: Model exposing ``to`` and ``sample``; required when
                ``sampling_type == 'pMPNN'``.

        Raises:
            ValueError: ``sampling_type == 'pMPNN'`` without a model.
            NotImplementedError: Unknown ``sampling_type`` or
                ``bfactor_recovery_metric``.
        """
        self.num_sequences = num_sequences
        self.temperature = sampling_temperature
        self.primitively_sampled = None
        self.hard_sampled = None
        # Loaded lazily by load_bfactor_predictor().
        self.bfactor_predictor = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Compare against None explicitly: container modules define __len__,
        # so truthiness is not a reliable "was a model supplied" test.
        self.pMPNN_model = pMPNN_model.to(self.device) if pMPNN_model is not None else None

        if sampling_type == 'primitive':
            self.chosen_sampling = self.primitive_sampling
        elif sampling_type == 'pMPNN':
            if pMPNN_model is None:
                raise ValueError('pMPNN model must be provided for pMPNN sampling')
            self.chosen_sampling = self.proteinMPNN_sampling
        else:
            raise NotImplementedError(f'Unsupported sampling_type: {sampling_type!r}')

        if bfactor_recovery_metric == 'pearson_R':
            self.bfactor_recovery_metric = torchmetrics.PearsonCorrCoef().to(self.device)
        else:
            raise NotImplementedError(f'Unsupported bfactor_recovery_metric: {bfactor_recovery_metric!r}')

    @staticmethod
    def _masked_recovery(pred_seq, gt_seq, mask):
        """Return the fraction of masked-in positions where ``pred_seq`` equals ``gt_seq``.

        The result is NaN when ``mask`` sums to zero.
        """
        matches = pred_seq == gt_seq
        return (matches*mask).sum()/(mask.sum())

    def softmax_with_temperature(self, logits):
        """Return ``softmax(logits / temperature)`` over the last dimension.

        Raises:
            ValueError: If the temperature is a non-positive Python number.
        """
        # A zero temperature produces inf/NaN and a negative one silently
        # inverts the distribution, so reject both up front.
        if isinstance(self.temperature, (int, float)) and self.temperature <= 0:
            raise ValueError(f'Sampling temperature must be positive, got {self.temperature}')
        scaled_logits = logits / self.temperature
        return torch.softmax(scaled_logits, dim=-1)

    def sample_from_logits(self, logits):
        """Draw one class index per position from the tempered softmax of ``logits``."""
        probabilities = self.softmax_with_temperature(logits)
        categorical = torch.distributions.Categorical(probabilities)
        return categorical.sample()

    def proteinMPNN_sampling(self, logits, batch = None):
        """Draw ``num_sequences`` sequences by calling ``pMPNN_model.sample``.

        Args:
            logits: Unused; accepted so both samplers share one call signature.
            batch: Mapping with the keys ``'X'``, ``'S'``, ``'mask'``,
                ``'chain_M'``, ``'chain_M_pos'``, ``'residue_idx'`` and
                ``'chain_encoding_all'``.

        Returns:
            List with the ``'S'`` entry of every ``sample`` result.

        Raises:
            ValueError: If ``batch`` is missing or empty.
        """
        if not batch:
            raise ValueError('Batch featurized for ProteinMPNN must be provided for pMPNN sampling')

        with torch.no_grad():
            # The batch does not change between draws, so unpack it once.
            X, S, mask = batch['X'], batch['S'], batch['mask']
            chain_M, chain_M_pos = batch['chain_M'], batch['chain_M_pos']
            residue_idx, chain_encoding_all = batch['residue_idx'], batch['chain_encoding_all']

            retVal = []
            for _ in range(self.num_sequences):
                # Fresh noise per draw; it is the only input that differs between calls.
                randn = torch.randn(chain_M.shape, device=X.device)
                sampled_seq = self.pMPNN_model.sample(
                    X=X,
                    randn=randn,
                    S_true=S,
                    chain_mask=chain_M,
                    chain_M_pos=chain_M_pos,
                    chain_encoding_all=chain_encoding_all,
                    residue_idx=residue_idx,
                    mask=mask,
                    temperature=self.temperature,
                )
                retVal.append(sampled_seq['S'])
            return retVal

    def primitive_sampling(self, logits, batch = None): #leads to poor sequences ignoring the context of already decoded AAs
        """Draw ``num_sequences`` independent samples from ``logits``.

        ``batch`` is unused; it is accepted so both samplers share one call
        signature.
        """
        return [self.sample_from_logits(logits) for _ in range(self.num_sequences)]

    def load_bfactor_predictor(self, config_path = './src/models/configs/FlexibilityProtTrans.yaml'):
        """Build the B-factor predictor from a YAML config and store it on the instance."""
        print('Loading model based on the configs:', config_path)
        # Imported here rather than at module level, as in the original code.
        from src.models.prottrans import FlexibilityProtTrans
        flex_params = load_yaml_config(config_path)
        predictor = FlexibilityProtTrans(**flex_params)
        # Keep the predictor on the same device as the inputs built in
        # eval_bfactor_profile_recovery, and in inference mode since it is only
        # used for scoring. Guarded because the predictor type is defined
        # outside this file.
        if isinstance(predictor, torch.nn.Module):
            predictor = predictor.to(self.device)
            predictor.eval()
        self.bfactor_predictor = predictor

    def hard_sampling(self, logits):
        """Return the argmax class index over the last dimension of ``logits``."""
        return torch.argmax(logits, dim=-1)

    def eval_oracle_recovery(self, gt_seq, logits, mask, batch = None):
        """Compute argmax recovery and the best recovery over all candidates.

        Returns:
            ``(hard_recovery, oracle_recovery)``: the masked recovery of the
            argmax sequence, and the maximum masked recovery over the argmax
            sequence plus every sampled sequence.
        """
        sampled_seqs = self.chosen_sampling(logits, batch=batch)
        hard_recovery = self._masked_recovery(self.hard_sampling(logits), gt_seq, mask)

        # The argmax sequence is a candidate too, so oracle >= hard always holds.
        recoveries = [hard_recovery]
        recoveries.extend(self._masked_recovery(seq, gt_seq, mask) for seq in sampled_seqs)

        return hard_recovery, max(recoveries)

    def eval_bfactor_profile_recovery(self, gt_bfactors, gt_seq, logits, mask, batch = None):
        """Return the sequence recovery of the candidate with the best B-factor score.

        Every sampled sequence plus the argmax sequence is passed through the
        B-factor predictor (loaded on first use); the candidate whose predicted
        profile scores highest under ``bfactor_recovery_metric`` against
        ``gt_bfactors`` is selected and its masked recovery against ``gt_seq``
        is returned.
        """
        if self.bfactor_predictor is None:
            self.load_bfactor_predictor()

        candidates = list(self.chosen_sampling(logits, batch=batch))
        candidates.append(self.hard_sampling(logits))

        scored = []
        # Predictions are only used for scoring, so no autograd graph is needed.
        with torch.no_grad():
            for seq in tqdm(candidates):
                # NOTE(review): 33 classes presumably matches the predictor's
                # vocabulary - confirm against FlexibilityProtTrans.
                one_hot_seq = F.one_hot(seq, num_classes=33)
                flex_model_input = one_hot_seq.permute(0, 2, 1).float().to(self.device)
                # Drops the last position along dim 1 and keeps channel 0;
                # presumably this strips a trailing special token - TODO confirm.
                pred_bfactors = self.bfactor_predictor(flex_model_input)['predicted_normalized_bfactors'][:,:-1,0]

                # Score only positions that are masked in and have a finite prediction.
                valid = (mask*~torch.isnan(pred_bfactors)).int()
                selected = torch.where(valid)
                score = self.bfactor_recovery_metric(pred_bfactors[selected], gt_bfactors[selected])
                scored.append((score, seq))

        # Pair scores with sequences instead of keying a dict by tensor objects.
        _, best_seq = max(scored, key=lambda pair: pair[0])
        return self._masked_recovery(best_seq, gt_seq, mask)

    def eval_multiple_predictions_oracle(self, predictions):
        """Run ``eval_oracle_recovery`` on every prediction.

        Each prediction is a mapping with the keys ``'log_probs'``, ``'mask'``,
        ``'original_sequence'`` and ``'batch'``.

        Returns:
            ``(hard_recovery, oracle_recovery)`` as two lists with one entry
            per prediction.
        """
        hard_recovery, oracle_recovery = [], []
        print('Evaluating sequence recovery by oracle...')
        for pred in tqdm(predictions):
            _hr, _or = self.eval_oracle_recovery(logits = pred['log_probs'], mask = pred['mask'], gt_seq = pred['original_sequence'], batch=pred['batch'])
            hard_recovery.append(_hr)
            oracle_recovery.append(_or)
        return hard_recovery, oracle_recovery
    

class SamplerGrid():
    """Grid of ``SequenceSampler`` objects, one per (temperature, sample size) pair."""

    def __init__(self, sample_sizes: list, sampling_temperatures: list, sampling_type = 'primitive', pMPNN_model = None) -> None:
        """Build one sampler per combination of temperature and sample size.

        Temperatures vary in the outer loop and sample sizes in the inner one,
        which fixes the order of ``self.samplers``.
        """
        grid = []
        for temperature in sampling_temperatures:
            for size in sample_sizes:
                grid.append(SequenceSampler(size, temperature, sampling_type, pMPNN_model=pMPNN_model))
        self.samplers = grid
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def get_optimal_sampler(self, predictions):
        """Return the sampler with the highest mean oracle recovery on ``predictions``.

        Every sampler is evaluated with ``eval_multiple_predictions_oracle``;
        its mean hard and oracle recoveries are printed. On ties the earliest
        sampler in ``self.samplers`` wins.
        """
        mean_oracle_scores = []
        for candidate in self.samplers:
            hard_scores, oracle_scores = candidate.eval_multiple_predictions_oracle(predictions)
            mean_hard = sum(hard_scores)/len(hard_scores)
            mean_oracle = sum(oracle_scores)/len(oracle_scores)
            mean_oracle_scores.append(mean_oracle)
            print(f'T = {candidate.temperature}, Sample_size = {candidate.num_sequences}, Average hard recovery: {mean_hard}, average oracle recovery: {mean_oracle}')

        top_score = max(mean_oracle_scores)
        winner = self.samplers[mean_oracle_scores.index(top_score)]
        print(f'Best sampler: T = {winner.temperature}, Sample_size = {winner.num_sequences}, with oracle recovery: {top_score}')
        return winner

    def eval_bfactor_selection(self, predictions, sampler = None):
        """Return the mean sequence recovery when candidates are chosen by B-factor profile.

        When no ``sampler`` is given, the one returned by
        ``get_optimal_sampler(predictions)`` is used. Each prediction is a
        mapping with the keys ``'log_probs'``, ``'mask'``,
        ``'original_sequence'``, ``'gt_bfactors'`` and ``'batch'``.
        """
        chosen = sampler or self.get_optimal_sampler(predictions)
        print('Evaluating sequence recovery by bfactor profile...')
        recoveries = [
            chosen.eval_bfactor_profile_recovery(
                logits = pred['log_probs'],
                mask = pred['mask'],
                gt_seq = pred['original_sequence'],
                gt_bfactors = pred['gt_bfactors'],
                batch = pred['batch'],
            )
            for pred in tqdm(predictions)
        ]
        avg_recovery = sum(recoveries)/len(recoveries)
        print(f'Average sequence recovery by bfactor profile: {avg_recovery}')
        return avg_recovery