""" Metrics. """
import logging

import mdtraj as md
import numpy as np
import torch
import tree
from utils import new_pdbUtils as du
from utils import experiments as eu
from openfold_np import residue_constants
from tmtools import tm_align

# Index of the C-alpha atom in the atom ordering used by residue_constants.
CA_IDX = residue_constants.atom_order['CA']

# Keys copied from the "between_residues" structural-violation results.
INTER_VIOLATION_METRICS = [
    "bonds_c_n_loss_mean",
    "angles_ca_c_n_loss_mean",
    "clashes_mean_loss",
]

# Keys produced by calc_mdtraj_metrics.
SHAPE_METRICS = [
    "coil_percent",
    "helix_percent",
    "strand_percent",
    "radius_of_gyration",
]

# Keys produced from the CA-only geometry checks.
CA_VIOLATION_METRICS = [
    "ca_ca_bond_dev",
    "ca_ca_valid_percent",
    "ca_steric_clash_percent",
    "num_ca_steric_clashes",
]

EVAL_METRICS = [
    "tm_score",
]

ALL_METRICS = (
    INTER_VIOLATION_METRICS
    + SHAPE_METRICS
    + CA_VIOLATION_METRICS
    + EVAL_METRICS
)


def calc_tm_score(pos_1, pos_2, seq_1, seq_2):
    """Run tm_align on two structures and return both normalised TM-scores.

    Args:
        pos_1: Coordinates of the first chain, passed straight to tm_align.
        pos_2: Coordinates of the second chain, passed straight to tm_align.
        seq_1: Sequence of the first chain.
        seq_2: Sequence of the second chain.

    Returns:
        Tuple ``(tm_norm_chain1, tm_norm_chain2)`` read from the tm_align
        result.
    """
    tm_results = tm_align(pos_1, pos_2, seq_1, seq_2)
    return tm_results.tm_norm_chain1, tm_results.tm_norm_chain2


def calc_perplexity(pred, labels, mask):
    """Return the masked mean of per-position perplexities.

    Args:
        pred: Array whose last axis holds per-class probabilities.
        labels: Integer class indices used to pick the true-class
            probability out of ``pred``.
        mask: Per-position weights; positions with weight 0 are ignored.

    Returns:
        ``sum(exp(-log(p_true + 1e-5)) * mask) / sum(mask)``. The result is
        nan/inf if ``mask`` sums to zero.
    """
    one_hot_labels = np.eye(pred.shape[-1])[labels]
    true_probs = np.sum(pred * one_hot_labels, axis=-1)
    # 1e-5 keeps the log finite when the true class gets probability 0.
    ce = -np.log(true_probs + 1e-5)
    per_res_perplexity = np.exp(ce)
    return np.sum(per_res_perplexity * mask) / np.sum(mask)


def calc_mdtraj_metrics(pdb_path):
    """Compute secondary-structure fractions and radius of gyration.

    Args:
        pdb_path: Path of the structure file loaded with ``md.load``.

    Returns:
        Dict with keys ``non_coil_percent``, ``coil_percent``,
        ``helix_percent``, ``strand_percent`` and ``radius_of_gyration``.
        If an IndexError is raised while loading or analysing the file, the
        error is printed and every value is 0.0.
    """
    try:
        traj = md.load(pdb_path)
        pdb_ss = md.compute_dssp(traj, simplified=True)
        pdb_coil_percent = np.mean(pdb_ss == 'C')
        pdb_helix_percent = np.mean(pdb_ss == 'H')
        pdb_strand_percent = np.mean(pdb_ss == 'E')
        pdb_ss_percent = pdb_helix_percent + pdb_strand_percent
        # Radius of gyration of the first frame only.
        pdb_rg = md.compute_rg(traj)[0]
    except IndexError as e:
        print('Error in calc_mdtraj_metrics: {}'.format(e))
        pdb_ss_percent = 0.0
        pdb_coil_percent = 0.0
        pdb_helix_percent = 0.0
        pdb_strand_percent = 0.0
        pdb_rg = 0.0
    return {
        'non_coil_percent': pdb_ss_percent,
        'coil_percent': pdb_coil_percent,
        'helix_percent': pdb_helix_percent,
        'strand_percent': pdb_strand_percent,
        'radius_of_gyration': pdb_rg,
    }


def calc_aligned_rmsd(pos_1, pos_2):
    """Align ``pos_1`` to ``pos_2`` and return the mean per-point distance.

    Args:
        pos_1: Point coordinates to be aligned.
        pos_2: Reference point coordinates.

    Returns:
        Mean Euclidean distance between the aligned ``pos_1`` and ``pos_2``.
    """
    # Presumably the first element returned by rigid_transform_3D is pos_1
    # after superposition onto pos_2 -- verify against utils.new_pdbUtils.
    aligned_pos_1 = du.rigid_transform_3D(pos_1, pos_2)[0]
    # NOTE(review): this is a mean of distances, not a root-mean-square,
    # despite the function name. Left unchanged for compatibility.
    return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1))


def protein_metrics(
        *,
        pdb_path,
        atom37_pos,
        gt_atom37_pos,
        gt_aatype,
        flow_mask,
):
    """Collect shape, violation, CA-geometry and TM-score metrics.

    Args:
        pdb_path: Structure file used for the mdtraj-based metrics.
        atom37_pos: Predicted atom coordinates; the second-to-last axis is
            indexed with ``CA_IDX`` (assumed [N, 37, 3] -- TODO confirm).
        gt_atom37_pos: Ground-truth coordinates, indexed like
            ``atom37_pos``.
        gt_aatype: Ground-truth residue types, converted to a sequence with
            ``du.aatype_to_seq``.
        flow_mask: Per-residue mask; only residues that are set here and
            have at least one non-zero atom are used for the TM-score.

    Returns:
        Dict of Python scalars holding the CA_VIOLATION_METRICS,
        ``tm_score``, everything returned by ``calc_mdtraj_metrics`` and the
        INTER_VIOLATION_METRICS. Each value is reduced with
        ``np.mean(x).item()``.
    """
    # SS percentage
    mdtraj_metrics = calc_mdtraj_metrics(pdb_path)

    # An atom counts as present when any of its coordinates is non-zero.
    atom37_mask = np.any(atom37_pos, axis=-1)
    atom37_diffuse_mask = flow_mask[..., None] * atom37_mask
    prot = eu.create_full_prot(atom37_pos, atom37_diffuse_mask)
    # NOTE(review): `amber_minimize` is not bound by any import visible in
    # this file, so this line raises NameError unless the name is provided
    # elsewhere -- confirm the intended module and import it.
    violation_metrics = amber_minimize.get_violation_metrics(prot)
    struct_violations = violation_metrics["structural_violations"]
    inter_violations = struct_violations["between_residues"]

    # Geometry
    bb_mask = np.any(atom37_mask, axis=-1)
    ca_pos = atom37_pos[..., CA_IDX, :][bb_mask.astype(bool)]
    ca_ca_bond_dev, ca_ca_valid_percent = ca_ca_distance(ca_pos)
    num_ca_steric_clashes, ca_steric_clash_percent = ca_ca_clashes(ca_pos)

    # Eval
    bb_diffuse_mask = (flow_mask * bb_mask).astype(bool)
    unpad_gt_scaffold_pos = gt_atom37_pos[..., CA_IDX, :][bb_diffuse_mask]
    unpad_pred_scaffold_pos = atom37_pos[..., CA_IDX, :][bb_diffuse_mask]
    seq = du.aatype_to_seq(gt_aatype[bb_diffuse_mask])
    # Keep the score normalised by the second chain (the ground truth).
    _, tm_score = calc_tm_score(
        unpad_pred_scaffold_pos, unpad_gt_scaffold_pos, seq, seq
    )

    metrics_dict = {
        "ca_ca_bond_dev": ca_ca_bond_dev,
        "ca_ca_valid_percent": ca_ca_valid_percent,
        "ca_steric_clash_percent": ca_steric_clash_percent,
        "num_ca_steric_clashes": num_ca_steric_clashes,
        "tm_score": tm_score,
        **mdtraj_metrics,
    }
    for k in INTER_VIOLATION_METRICS:
        metrics_dict[k] = inter_violations[k]
    # Collapse every entry (array or scalar) to a plain Python number.
    metrics_dict = tree.map_structure(lambda x: np.mean(x).item(), metrics_dict)
    return metrics_dict


def ca_ca_distance(ca_pos, tol=0.1):
    """Measure how consecutive CA-CA distances compare to the ideal length.

    Args:
        ca_pos: [N, 3] array of CA positions in chain order.
        tol: Allowed excess over ``residue_constants.ca_ca`` for a bond to
            count as valid.

    Returns:
        Tuple ``(ca_ca_dev, ca_ca_valid)``: the mean absolute deviation from
        ``residue_constants.ca_ca`` and the fraction of consecutive
        distances below ``residue_constants.ca_ca + tol``. Both are nan when
        there are fewer than two positions.
    """
    # Distance between each residue and its predecessor (N - 1 values).
    ca_bond_dists = np.linalg.norm(ca_pos[1:] - ca_pos[:-1], axis=-1)
    ca_ca_dev = np.mean(np.abs(ca_bond_dists - residue_constants.ca_ca))
    # Only over-long bonds are treated as invalid; there is no lower bound.
    ca_ca_valid = np.mean(ca_bond_dists < (residue_constants.ca_ca + tol))
    return ca_ca_dev, ca_ca_valid


def ca_ca_clashes(ca_pos, tol=1.5):
    """Count CA pairs that are closer than ``tol``.

    Args:
        ca_pos: [N, 3] array of CA positions.
        tol: Distance below which a pair is a steric clash.

    Returns:
        Tuple ``(num_clashes, clash_fraction)`` computed over all N*(N-1)/2
        unordered pairs. ``clash_fraction`` is nan when there are fewer than
        two positions.
    """
    ca_ca_dists2d = np.linalg.norm(
        ca_pos[:, None, :] - ca_pos[None, :, :], axis=-1)
    # Strict upper triangle: every unordered pair exactly once, without the
    # self-distances on the diagonal. Selecting by index (rather than by
    # "distance > 0") keeps coincident atoms, which are clashes too.
    pair_rows, pair_cols = np.triu_indices(ca_pos.shape[0], k=1)
    inter_dists = ca_ca_dists2d[pair_rows, pair_cols]
    clashes = inter_dists < tol
    return np.sum(clashes), np.mean(clashes)


def calc_ca_ca_metrics(ca_pos, bond_tol=0.1, clash_tol=1.0):
    """Calculate CA-CA distance metrics.

    Args:
        ca_pos: [N, 3] array (or torch tensor) of CA positions. A flat 1-D
            input is reshaped to [N, 3].
        bond_tol: Tolerance for CA-CA bond length deviation
        clash_tol: Distance threshold for steric clashes

    Returns:
        Dictionary with ``ca_ca_deviation``, ``ca_ca_valid_percent`` and
        ``num_ca_ca_clashes``.

    Raises:
        ValueError: If ``ca_pos`` has more than two dimensions.
    """
    # Debug logging
    logger = logging.getLogger(__name__)
    logger.info("Input ca_pos shape: %s", ca_pos.shape)
    logger.info("Input ca_pos type: %s", type(ca_pos))

    # Ensure input is numpy array
    if isinstance(ca_pos, torch.Tensor):
        ca_pos = ca_pos.detach().cpu().numpy()

    # Ensure shape is [N, 3]
    if len(ca_pos.shape) == 1:
        ca_pos = ca_pos.reshape(-1, 3)
    elif len(ca_pos.shape) > 2:
        raise ValueError(f"Expected ca_pos shape [N, 3], got {ca_pos.shape}")

    logger.info("Processed ca_pos shape: %s", ca_pos.shape)

    # Reuse the shared helpers so the numbers agree with protein_metrics.
    ca_ca_dev, ca_ca_valid = ca_ca_distance(ca_pos, tol=bond_tol)
    num_clashes, _ = ca_ca_clashes(ca_pos, tol=clash_tol)

    return {
        'ca_ca_deviation': ca_ca_dev,
        'ca_ca_valid_percent': ca_ca_valid,
        'num_ca_ca_clashes': num_clashes,
    }