"""Lightning training / validation / inference wrapper around ProteinFlow."""
import logging
import os
import random
import time
from typing import Any

import torch
import wandb
import numpy as np
import pandas as pd
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers.wandb import WandbLogger

from utils.experiments import write_prot_to_pdb, save_traj
from utils import metrics
from models.proteinflow import ProteinFlow
from utils import all_atom
from utils import so3Utils as su
from utils import residue_constants as rc
from utils.flows import Interpolant
from utils.modelUtils import t_stratified_loss, to_numpy


class ProteinFlowModule(LightningModule):
    """LightningModule that trains a ``ProteinFlow`` model with an ``Interpolant``.

    The module owns the network (``self.model``), the interpolant used to
    corrupt batches and to draw samples (``self.interpolant``), and the
    per-epoch buffers that collect validation samples and metrics.
    """

    def __init__(self, cfg, folding_cfg=None):
        """Build the model and interpolant from ``cfg``.

        Args:
            cfg: Config object exposing ``experiment``, ``model``, ``data``
                and ``interpolant`` attributes.
            folding_cfg: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self._print_logger = logging.getLogger(__name__)
        self._exp_cfg = cfg.experiment
        self._model_cfg = cfg.model
        self._data_cfg = cfg.data
        self._interpolant_cfg = cfg.interpolant

        # Set-up vector field prediction model
        self.model = ProteinFlow(cfg.model)

        # Set-up interpolant
        self.interpolant = Interpolant(cfg.interpolant)

        # Validation samples are written next to the checkpoints.
        self._sample_write_dir = self._exp_cfg.checkpointer.dirpath
        os.makedirs(self._sample_write_dir, exist_ok=True)

        self.validation_epoch_metrics = []
        self.validation_epoch_samples = []
        self.save_hyperparameters()

    def on_train_start(self):
        """Start the wall-clock timer used for per-epoch timing."""
        self._epoch_start_time = time.time()

    def on_train_epoch_end(self):
        """Log the minutes elapsed since the previous epoch boundary."""
        epoch_time = (time.time() - self._epoch_start_time) / 60.0
        self.log(
            'train/epoch_time_minutes',
            epoch_time,
            on_step=False,
            on_epoch=True,
            prog_bar=False
        )
        # Restart the timer for the next epoch.
        self._epoch_start_time = time.time()

    def model_step(self, noisy_batch: Any):
        """Run the model on a corrupted batch and compute per-example losses.

        Args:
            noisy_batch: Mapping holding at least ``res_mask``, ``trans_1``,
                ``rotmats_1``, ``rotmats_t`` and ``t`` (plus ``res_plddt``
                when ``training.min_plddt_mask`` is set).

        Returns:
            Dict with keys ``bb_atom_loss``, ``trans_loss``,
            ``dist_mat_loss``, ``auxiliary_loss``, ``rots_vf_loss`` and
            ``se3_vf_loss``; each value is reduced over everything but the
            batch dimension.

        Raises:
            ValueError: If the combined loss contains NaN.
        """
        training_cfg = self._exp_cfg.training
        loss_mask = noisy_batch['res_mask']
        if training_cfg.min_plddt_mask is not None:
            plddt_mask = noisy_batch['res_plddt'] > training_cfg.min_plddt_mask
            # NOTE(review): this multiplies in place, so the tensor stored in
            # noisy_batch['res_mask'] (passed to self.model below) is updated
            # too. Kept as-is to preserve training behaviour -- confirm intent.
            loss_mask *= plddt_mask
        num_batch, num_res = loss_mask.shape

        # Ground truth labels
        gt_trans_1 = noisy_batch['trans_1']
        gt_rotmats_1 = noisy_batch['rotmats_1']
        rotmats_t = noisy_batch['rotmats_t']
        gt_rot_vf = su.calc_rot_vf(
            rotmats_t, gt_rotmats_1.type(torch.float32))
        gt_bb_atoms = all_atom.to_atom37(gt_trans_1, gt_rotmats_1)[:, :, :3]

        # Timestep used for normalization.
        t = noisy_batch['t']
        norm_scale = 1 - torch.min(
            t[..., None], torch.tensor(training_cfg.t_normalize_clip))

        # Model output predictions.
        model_output = self.model(noisy_batch)
        pred_trans_1 = model_output['pred_trans']
        pred_rotmats_1 = model_output['pred_rotmats']
        pred_rots_vf = su.calc_rot_vf(rotmats_t, pred_rotmats_1)

        # Backbone atom loss
        pred_bb_atoms = all_atom.to_atom37(pred_trans_1, pred_rotmats_1)[:, :, :3]
        bb_scale = training_cfg.bb_atom_scale / norm_scale[..., None]
        gt_bb_atoms *= bb_scale
        pred_bb_atoms *= bb_scale
        # Three backbone atoms are kept per unmasked residue.
        loss_denom = torch.sum(loss_mask, dim=-1) * 3
        bb_atom_loss = torch.sum(
            (gt_bb_atoms - pred_bb_atoms) ** 2 * loss_mask[..., None, None],
            dim=(-1, -2, -3)
        ) / loss_denom

        # Translation VF loss
        trans_error = (gt_trans_1 - pred_trans_1) / norm_scale * training_cfg.trans_scale
        trans_loss = training_cfg.translation_loss_weight * torch.sum(
            trans_error ** 2 * loss_mask[..., None],
            dim=(-1, -2)
        ) / loss_denom

        # Rotation VF loss
        rots_vf_error = (gt_rot_vf - pred_rots_vf) / norm_scale
        rots_vf_loss = training_cfg.rotation_loss_weights * torch.sum(
            rots_vf_error ** 2 * loss_mask[..., None],
            dim=(-1, -2)
        ) / loss_denom

        # Pairwise distance loss
        gt_flat_atoms = gt_bb_atoms.reshape([num_batch, num_res * 3, 3])
        gt_pair_dists = torch.linalg.norm(
            gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1)
        pred_flat_atoms = pred_bb_atoms.reshape([num_batch, num_res * 3, 3])
        pred_pair_dists = torch.linalg.norm(
            pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1)

        # Repeat the residue mask once per backbone atom. The same flattened
        # mask serves both axes of the pairwise matrix.
        flat_loss_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3))
        flat_loss_mask = flat_loss_mask.reshape([num_batch, num_res * 3])

        gt_pair_dists = gt_pair_dists * flat_loss_mask[..., None]
        pred_pair_dists = pred_pair_dists * flat_loss_mask[..., None]
        pair_dist_mask = flat_loss_mask[..., None] * flat_loss_mask[:, None, :]

        dist_mat_loss = torch.sum(
            (gt_pair_dists - pred_pair_dists) ** 2 * pair_dist_mask,
            dim=(1, 2))
        dist_mat_loss /= (torch.sum(pair_dist_mask, dim=(1, 2)) - num_res)

        se3_vf_loss = trans_loss + rots_vf_loss
        # Auxiliary terms only contribute for examples past aux_loss_t_pass.
        auxiliary_loss = (bb_atom_loss + dist_mat_loss) * (
            t[:, 0] > training_cfg.aux_loss_t_pass
        )
        auxiliary_loss *= self._exp_cfg.training.aux_loss_weight
        se3_vf_loss += auxiliary_loss
        if torch.isnan(se3_vf_loss).any():
            raise ValueError('NaN loss encountered')
        return {
            "bb_atom_loss": bb_atom_loss,
            "trans_loss": trans_loss,
            "dist_mat_loss": dist_mat_loss,
            "auxiliary_loss": auxiliary_loss,
            "rots_vf_loss": rots_vf_loss,
            "se3_vf_loss": se3_vf_loss
        }

    def validation_step(self, batch: Any, batch_idx: int):
        """Sample structures shaped like ``batch`` and record their metrics.

        Each sample is written to a PDB file under the sample directory,
        optionally queued for the W&B samples table, and scored; the
        resulting DataFrame is appended to ``validation_epoch_metrics``.
        """
        res_mask = batch['res_mask']
        self.interpolant.set_device(res_mask.device)
        num_batch, num_res = res_mask.shape

        # First returned item, last entry: the same indexing the original
        # used. to_numpy (rather than a bare .numpy()) matches predict_step
        # and does not require the tensor to already live on the CPU.
        final_sample = self.interpolant.sample(
            num_batch,
            num_res,
            self.model,
        )[0][-1]
        samples = to_numpy(final_sample)

        batch_metrics = []
        for i in range(num_batch):

            # Write out sample to PDB file
            final_pos = samples[i]
            saved_path = write_prot_to_pdb(
                final_pos,
                os.path.join(
                    self._sample_write_dir,
                    f'sample_{i}_idx_{batch_idx}_len_{num_res}.pdb'),
                no_indexing=True
            )
            if isinstance(self.logger, WandbLogger):
                self.validation_epoch_samples.append(
                    [saved_path, self.global_step, wandb.Molecule(saved_path)]
                )

            mdtraj_metrics = metrics.calc_mdtraj_metrics(saved_path)
            ca_idx = rc.atom_order['CA']
            ca_ca_metrics = metrics.calc_ca_ca_metrics(final_pos[:, ca_idx])
            batch_metrics.append((mdtraj_metrics | ca_ca_metrics))

        batch_metrics = pd.DataFrame(batch_metrics)
        self.validation_epoch_metrics.append(batch_metrics)

    def on_validation_epoch_end(self):
        """Flush buffered validation samples and log the mean of each metric."""
        if len(self.validation_epoch_samples) > 0:
            self.logger.log_table(
                key='valid/samples',
                columns=["sample_path", "global_step", "Protein"],
                data=self.validation_epoch_samples)
            self.validation_epoch_samples.clear()

        # pd.concat raises ValueError on an empty sequence; nothing to log
        # when no validation batch produced metrics.
        if not self.validation_epoch_metrics:
            return

        val_epoch_metrics = pd.concat(self.validation_epoch_metrics)
        for metric_name, metric_val in val_epoch_metrics.mean().to_dict().items():
            self._log_scalar(
                f'valid/{metric_name}',
                metric_val,
                on_step=False,
                on_epoch=True,
                prog_bar=False,
                batch_size=len(val_epoch_metrics),
            )
        self.validation_epoch_metrics.clear()

    def _log_scalar(
            self,
            key,
            value,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            batch_size=None,
            sync_dist=False,
            rank_zero_only=True
        ):
        """Forward to ``self.log`` with this module's logging defaults.

        Raises:
            ValueError: If both ``sync_dist`` and ``rank_zero_only`` are set.
        """
        if sync_dist and rank_zero_only:
            raise ValueError('Unable to sync dist when rank_zero_only=True')
        self.log(
            key,
            value,
            on_step=on_step,
            on_epoch=on_epoch,
            prog_bar=prog_bar,
            batch_size=batch_size,
            sync_dist=sync_dist,
            rank_zero_only=rank_zero_only
        )

    def training_step(self, batch: Any, stage: int):
        """Corrupt ``batch``, compute losses, log them, and return the loss.

        Args:
            batch: Training batch; must contain ``res_mask``.
            stage: Second positional argument supplied by the trainer; it is
                not read by this method.

        Returns:
            Scalar training loss.
        """
        step_start_time = time.time()
        self.interpolant.set_device(batch['res_mask'].device)
        noisy_batch = self.interpolant.corrupt_batch(batch)

        # Self-conditioning: about half of the time, feed the model its own
        # gradient-free translation prediction.
        if self._interpolant_cfg.self_condition and random.random() > 0.5:
            with torch.no_grad():
                model_sc = self.model(noisy_batch)
                noisy_batch['trans_sc'] = model_sc['pred_trans']

        batch_losses = self.model_step(noisy_batch)
        num_batch = batch_losses['bb_atom_loss'].shape[0]
        total_losses = {
            k: torch.mean(v) for k, v in batch_losses.items()
        }
        for k, v in total_losses.items():
            self._log_scalar(
                f"train/{k}", v, prog_bar=False, batch_size=num_batch)

        # Losses to track. Stratified across t.
        t = torch.squeeze(noisy_batch['t'])
        self._log_scalar(
            "train/t",
            np.mean(to_numpy(t)),
            prog_bar=False, batch_size=num_batch)
        for loss_name, loss_dict in batch_losses.items():
            stratified_losses = t_stratified_loss(
                t, loss_dict, loss_name=loss_name)
            for k, v in stratified_losses.items():
                self._log_scalar(
                    f"train/{k}", v, prog_bar=False, batch_size=num_batch)

        # Training throughput
        self._log_scalar(
            "train/length", batch['res_mask'].shape[1], prog_bar=False, batch_size=num_batch)
        self._log_scalar(
            "train/batch_size", num_batch, prog_bar=False)
        step_time = time.time() - step_start_time
        self._log_scalar(
            "train/examples_per_second", num_batch / step_time)

        # NOTE(review): model_step already adds auxiliary_loss into
        # 'se3_vf_loss', so when training.loss == 'se3_vf_loss' the auxiliary
        # term is counted twice here. Left unchanged to keep training
        # numerics stable -- confirm whether this is intended.
        train_loss = (
            total_losses[self._exp_cfg.training.loss]
            + total_losses['auxiliary_loss']
        )
        self._log_scalar(
            "train/loss", train_loss, batch_size=num_batch)
        return train_loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the model parameters."""
        return torch.optim.AdamW(
            params=self.model.parameters(),
            **self._exp_cfg.optimizer
        )

    def predict_step(self, batch, batch_idx):
        """Sample one structure of the requested length and save its trajectory.

        Reads ``self._infer_cfg`` and ``self._output_dir``, which are not
        assigned in the visible part of this class and therefore must be set
        on the instance before prediction. Instances whose output directory
        already holds ``top_sample.csv`` are skipped.
        """
        # Use the device this module lives on instead of assuming CUDA.
        device = self.device
        interpolant = Interpolant(self._infer_cfg.interpolant)
        interpolant.set_device(device)

        sample_length = batch['num_res'].item()
        diffuse_mask = torch.ones(1, sample_length)
        sample_id = batch['sample_id'].item()
        sample_dir = os.path.join(
            self._output_dir, f'length_{sample_length}', f'sample_{sample_id}')
        top_sample_csv_path = os.path.join(sample_dir, 'top_sample.csv')
        if os.path.exists(top_sample_csv_path):
            self._print_logger.info(
                f'Skipping instance {sample_id} length {sample_length}')
            return

        atom37_traj, model_traj, _ = interpolant.sample(
            1, sample_length, self.model
        )

        os.makedirs(sample_dir, exist_ok=True)
        bb_traj = to_numpy(torch.concat(atom37_traj, dim=0))
        _ = save_traj(
            bb_traj[-1],
            bb_traj,
            np.flip(to_numpy(torch.concat(model_traj, dim=0)), axis=0),
            to_numpy(diffuse_mask)[0],
            output_dir=sample_dir,
        )