import sys; sys.path.append('/huyuqi/xmyu/DiffSDS')
import inspect
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from omegaconf import OmegaConf
from torcheval.metrics.text import Perplexity

from src.interface.model_interface import MInterface_base
from src.tools.utils import cuda, load_yaml_config


class MInterface(MInterface_base):
    def __init__(self, model_name=None, loss=None, lr=None, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.load_model()

        self.use_dynamics = kwargs.get('use_dynamics', 0)
        # Fixed, non-trainable coefficient for the flexibility loss (device hardcoded to cuda:0).
        self.flex_loss_coeff = torch.tensor([kwargs.get('flex_loss_coeff', 0)], dtype=torch.float, device='cuda:0')
        self.flex_loss_coeff.requires_grad = False

        if self.use_dynamics:
            self.load_flex_predictor()
            self.flex_loss_type = kwargs.get('loss_fn', 0)
            if self.flex_loss_type == 'MSE':
                self.flex_loss_fn = nn.MSELoss(reduction='none')
            elif self.flex_loss_type == 'L1':
                self.flex_loss_fn = nn.L1Loss(reduction='none')
            elif self.flex_loss_type == 'DPO':
                self.flex_loss_fn = ...  # placeholder: no elementwise loss module for DPO
            else:
                raise ValueError(f"Unrecognized loss function type {self.flex_loss_type}")

        self.cross_entropy = nn.NLLLoss(reduction='none')
        os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)
        self.control_sum_recovery = 0
        self.control_sum_batch_sizes = 0
        self.grad_normalization = kwargs.get('grad_normalization', 0)

        self.use_pmpnn_checkpoint = kwargs.get('use_pmpnn_checkpoint', 0)
        if self.use_pmpnn_checkpoint:
            print('Loading pmpnn checkpoint from {}'.format(self.model.pmpnn_init_weights_path))
            state_dict = torch.load(self.model.pmpnn_init_weights_path)['state_dict']  # ['module']
            # Keep only the inner model's weights and strip the leading 'model.' prefix.
            state_dict = {key: value for key, value in state_dict.items() if key.startswith('model.')}
            state_dict = {key.replace('model.', '', 1): value for key, value in state_dict.items()}
            self.model.load_state_dict(state_dict)

        self.MSE = nn.MSELoss(reduction='none')
        self.automatic_optimization = False
        if self.hparams.use_dynamics:
            self.pearson = torchmetrics.PearsonCorrCoef()
            self.spearman = torchmetrics.SpearmanCorrCoef()
        self.validation_step_outputs = []
        self.test_step_outputs = []

        #### setting forward hook (kept for NaN debugging)
        # def forward_hook(module, input, output):
        #     def check_nan(tensor):
        #         if isinstance(tensor, torch.Tensor):
        #             if torch.isnan(tensor).any():
        #                 print(f"NaN detected in the output of {type(module).__name__}")
        #                 print(f"Tensor shape: {tensor.shape}")
        #                 print(f"Tensor stats: mean={tensor.mean()}, std={tensor.std()}, min={tensor.min()}, max={tensor.max()}, all={torch.isnan(tensor).all()}")
        #         elif isinstance(tensor, tuple):
        #             for i, t in enumerate(tensor):
        #                 if isinstance(t, torch.Tensor):
        #                     if torch.isnan(t).any():
        #                         print(f"NaN detected in the output[{i}] of {type(module).__name__}")
        #                         print(f"Tensor shape: {t.shape}")
        #                         print(f"Tensor stats: mean={t.mean()}, std={t.std()}, min={t.min()}, max={t.max()}, all={torch.isnan(t).all()}")
        #     if isinstance(output, tuple):
        #         for i, out in enumerate(output):
        #             check_nan(out)
        #     else:
        #         check_nan(output)
        # for name, module in self.model.named_modules():
        #     module.register_forward_hook(forward_hook)
        # for name, module in self.flex_model.named_modules():
        #     module.register_forward_hook(forward_hook)
        ####

    def forward(self, batch, mode='train', temperature=1.0):
        # Optional Gaussian coordinate noise for data augmentation.
        if self.hparams.augment_eps > 0:
            batch['X'] = batch['X'] + self.hparams.augment_eps * torch.randn_like(batch['X'])

        batch = self.model._get_features(batch)
        results = self.model(batch)
        log_probs, mask = results['log_probs'], batch['mask']

        if len(log_probs.shape) == 3:
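            # Batched case: log_probs has shape (batch, length, vocab).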
            if self.hparams.use_dynamics:
                loss = self.combined_flex_aware_loss(batch, pred_log_probs=log_probs)
                # loss = loss_dict['combined_loss']
            else:
                loss = self.cross_entropy(log_probs.permute(0, 2, 1), batch['S'])
                loss = (loss * mask).sum() / mask.sum()
        elif len(log_probs.shape) == 2:
            if self.hparams.model_name == 'GVP':
                loss = self.cross_entropy(log_probs, batch.seq)
            else:
                loss = self.cross_entropy(log_probs, batch['S'])
            if self.hparams.model_name == 'AlphaDesign':
                loss += self.cross_entropy(results['log_probs0'], batch['S'])
            loss = (loss * mask).sum() / mask.sum()

        # Per-position correctness and masked sequence-recovery rate.
        cmp = log_probs.argmax(dim=-1) == batch['S']
        recovery = (cmp * mask).sum() / mask.sum()

        if mode in ('predict', 'eval'):
            return {'original_sequence': batch['S'], 'correct_positions': cmp, 'mask': mask,
                    'loss': loss, 'recovery': recovery, 'title': batch['title'],
                    'log_probs': log_probs, 'batch': batch}  # , 'gt_bfactors': batch['norm_bfactors']
        else:
            return loss, recovery

    def avgCorrelations(self, preds, gts, masks):
        # Average Pearson/Spearman correlations over samples, skipping samples
        # whose Pearson correlation is NaN (e.g. constant predictions).
        pearson_R = 0
        spearman_R = 0
        valid_datapoints = 0
        for pred, gt, mask in zip(preds, gts, masks):
            dpR = self.pearson(pred[torch.where(mask)], gt[torch.where(mask)])
            if torch.isnan(dpR):
                continue
            pearson_R += dpR
            spearman_R += self.spearman(pred[torch.where(mask)], gt[torch.where(mask)])
            valid_datapoints += 1
        if valid_datapoints == 0:
            return float('nan'), float('nan')
        return pearson_R / valid_datapoints, spearman_R / valid_datapoints

    def temperature_schedular(self, batch_idx):
        total_steps = self.hparams.steps_per_epoch * self.hparams.epoch
        initial_lr = 1.0
        circle_steps = total_steps // 100
        x = batch_idx / total_steps
        threshold = 0.48
        if x