Spaces:
Sleeping
Sleeping
| """Utils for evaluating bond length.""" | |
| import collections | |
| from typing import Tuple, Sequence, Dict, Optional | |
| import numpy as np | |
| from scipy import spatial as sci_spatial | |
| import matplotlib.pyplot as plt | |
| from utils.evaluation import eval_bond_length_config | |
| import utils.data as utils_data | |
# Key identifying a kind of bond: (atomic_num, atomic_num, bond_type).
BondType = Tuple[int, int, int]
# One measured bond: (bond_type, bond_length).
BondLengthData = Tuple[BondType, float]
# Maps each bond_type to the empirical distribution of its bond lengths.
BondLengthProfile = Dict[BondType, np.ndarray]
def get_distribution(distances: Sequence[float], bins: Optional[Sequence[float]] = None) -> np.ndarray:
    """Bin distances and return the normalised histogram.

    Each distance is assigned a bin index with ``np.searchsorted(bins, distances)``:
    index 0 collects values not greater than ``bins[0]`` and index ``len(bins)``
    collects values greater than ``bins[-1]``.

    Args:
        distances: Distances to bin.
        bins: Ascending bin edges. When omitted or ``None``,
            ``eval_bond_length_config.DISTANCE_BINS`` is used.

    Returns:
        np.ndarray: Array of length ``len(bins) + 1`` holding the fraction of
        distances falling in each bin. If ``distances`` is empty, an all-zero
        array is returned instead of dividing by zero.
    """
    if bins is None:
        # Resolved at call time so the config array is not captured as a
        # definition-time default.
        bins = eval_bond_length_config.DISTANCE_BINS
    n_slots = len(bins) + 1

    bin_indices = np.searchsorted(bins, distances)
    # minlength guarantees one slot per bin even when the upper bins are empty.
    bin_counts = np.bincount(bin_indices, minlength=n_slots)

    total = bin_counts.sum()
    if total == 0:
        # No samples: avoid the 0 / 0 that would yield NaNs and a RuntimeWarning.
        return np.zeros(n_slots)
    return bin_counts / total
def _format_bond_type(bond_type: BondType) -> BondType:
    """Return the bond key with its two atom entries in ascending order.

    The third entry (the bond category) is passed through unchanged, so both
    ``(a, b, c)`` and ``(b, a, c)`` yield the same key.
    """
    first, second, category = bond_type
    if first <= second:
        return first, second, category
    return second, first, category
def get_bond_length_profile(bond_lengths: Sequence[BondLengthData]) -> BondLengthProfile:
    """Group bond lengths by bond type and turn each group into a distribution.

    Args:
        bond_lengths: Iterable of ``(bond_type, bond_length)`` pairs.

    Returns:
        Dict mapping each bond type (normalised with ``_format_bond_type``) to
        the result of ``get_distribution`` over its collected lengths.
    """
    lengths_by_type = {}
    for raw_type, length in bond_lengths:
        key = _format_bond_type(raw_type)
        lengths_by_type.setdefault(key, []).append(length)

    profile = {}
    for key, lengths in lengths_by_type.items():
        profile[key] = get_distribution(lengths)
    return profile
def _bond_type_str(bond_type: BondType) -> str:
    """Render a bond key as ``'<atom1>-<atom2>|<category>'``."""
    first, second, category = bond_type
    return '{}-{}|{}'.format(first, second, category)
def eval_bond_length_profile(bond_length_profile: BondLengthProfile) -> Dict[str, Optional[float]]:
    """Score a bond-length profile against the reference distributions.

    For every bond type in ``eval_bond_length_config.EMPIRICAL_DISTRIBUTIONS``
    the Jensen-Shannon distance between the reference distribution and the
    matching entry of ``bond_length_profile`` is computed.

    Args:
        bond_length_profile: Mapping from bond type to its distribution.

    Returns:
        Dict keyed by ``'JSD_<atom1>-<atom2>|<category>'``. The value is the
        Jensen-Shannon distance, or ``None`` when the bond type is absent from
        ``bond_length_profile``.
    """
    reference = eval_bond_length_config.EMPIRICAL_DISTRIBUTIONS
    metrics = {}
    for bond_type, gt_distribution in reference.items():
        metric_name = f'JSD_{_bond_type_str(bond_type)}'
        if bond_type in bond_length_profile:
            metrics[metric_name] = sci_spatial.distance.jensenshannon(
                gt_distribution, bond_length_profile[bond_type])
        else:
            metrics[metric_name] = None
    return metrics
def get_pair_length_profile(pair_lengths):
    """Build the two pair-distance distributions used for evaluation.

    Args:
        pair_lengths: Sequence of entries whose item 0 is an element pair and
            whose item 1 is the distance (e.g. the output of
            ``pair_distance_from_pos_v``).

    Returns:
        Dict with two entries:
            'CC_2A': distribution of distances below 2 for pairs equal to
                ``(6, 6)``, binned on ``np.linspace(0, 2, 100)``.
            'All_12A': distribution of all distances below 12, binned on
                ``np.linspace(0, 12, 100)``.
    """
    short_cc = []
    for entry in pair_lengths:
        if entry[0] == (6, 6) and entry[1] < 2:
            short_cc.append(entry[1])

    # Second, separate pass over the input, mirroring the selection above.
    within_range = []
    for entry in pair_lengths:
        if entry[1] < 12:
            within_range.append(entry[1])

    profile = {}
    profile['CC_2A'] = get_distribution(short_cc, bins=np.linspace(0, 2, 100))
    profile['All_12A'] = get_distribution(within_range, bins=np.linspace(0, 12, 100))
    return profile
def eval_pair_length_profile(pair_length_profile):
    """Score a pair-distance profile against the reference distributions.

    Args:
        pair_length_profile: Mapping from profile key to its distribution.

    Returns:
        Dict keyed by ``'JSD_<key>'`` for every key of
        ``eval_bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS``. The value is
        the Jensen-Shannon distance to the reference, or ``None`` when the key
        is absent from ``pair_length_profile``.
    """
    reference = eval_bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS
    return {
        f'JSD_{key}': (
            sci_spatial.distance.jensenshannon(gt_distribution, pair_length_profile[key])
            if key in pair_length_profile
            else None
        )
        for key, gt_distribution in reference.items()
    }
def plot_distance_hist(pair_length_profile, metrics=None, save_path=None):
    """Plot reference vs. learned pair-distance histograms, one subplot per key.

    A subplot is drawn for every key of
    ``eval_bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS``, using the x values
    from ``eval_bond_length_config.PAIR_EMPIRICAL_BINS``.

    Args:
        pair_length_profile: Mapping from profile key to distribution array.
            A key missing here is drawn with the reference curve only.
        metrics: Optional mapping with ``'JSD_<key>'`` entries. When the entry
            is missing or ``None`` the subplot title is the bare key.
        save_path: If given, the figure is written there with ``plt.savefig``;
            otherwise it is displayed with ``plt.show``.
    """
    gt_profile = eval_bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS
    n_panels = len(gt_profile)
    plt.figure(figsize=(6 * n_panels, 4))
    for idx, (k, gt_distribution) in enumerate(gt_profile.items()):
        plt.subplot(1, n_panels, idx + 1)
        x = eval_bond_length_config.PAIR_EMPIRICAL_BINS[k]
        # The first histogram entry is skipped before plotting against x
        # (assumes each distribution holds len(x) + 1 entries -- confirm for
        # the stored reference data).
        plt.step(x, gt_distribution[1:])
        labels = ['True']
        # A key absent from the learned profile must not abort the whole
        # figure; only the reference curve is drawn for it.
        if k in pair_length_profile:
            plt.step(x, pair_length_profile[k][1:])
            labels.append('Learned')
        plt.legend(labels)

        # The score may be missing or None, and None cannot be formatted
        # with '.4f'.
        score = None if metrics is None else metrics.get(f'JSD_{k}')
        if score is not None:
            plt.title(f'{k} JS div: {score:.4f}')
        else:
            plt.title(k)

    if save_path is not None:
        plt.savefig(save_path)
    else:
        plt.show()
    plt.close()
def pair_distance_from_pos_v(pos, elements):
    """List the distance of every unordered pair of positions.

    Args:
        pos: numpy array of coordinates, one row per atom.
        elements: Per-atom labels, indexable in the same order as ``pos``.

    Returns:
        List of ``((elements[i], elements[j]), distance)`` for all ``i < j``,
        ordered by ``i`` and then ``j``.
    """
    # Full pairwise Euclidean distance matrix via broadcasting.
    offsets = pos[None, :] - pos[:, None]
    dist_matrix = np.sqrt(np.sum(offsets ** 2, axis=-1))

    n_atoms = len(pos)
    return [
        ((elements[i], elements[j]), dist_matrix[i, j])
        for i in range(n_atoms)
        for j in range(i + 1, n_atoms)
    ]
def bond_distance_from_mol(mol):
    """Collect the length of every bond in a molecule.

    Args:
        mol: Molecule object providing ``GetConformer().GetPositions()`` and
            ``GetBonds()`` (presumably an RDKit ``Mol`` with a conformer --
            confirm against callers).

    Returns:
        List of ``((begin_atomic_num, end_atomic_num, bond_type), distance)``,
        one entry per bond, where ``bond_type`` is looked up in
        ``utils_data.BOND_TYPES`` and ``distance`` is the Euclidean distance
        between the two bonded atoms' positions.
    """
    coords = mol.GetConformer().GetPositions()
    # Full pairwise Euclidean distance matrix via broadcasting.
    offsets = coords[None, :] - coords[:, None]
    dist_matrix = np.sqrt(np.sum(offsets ** 2, axis=-1))

    records = []
    for bond in mol.GetBonds():
        begin_num = bond.GetBeginAtom().GetAtomicNum()
        end_num = bond.GetEndAtom().GetAtomicNum()
        begin_idx = bond.GetBeginAtomIdx()
        end_idx = bond.GetEndAtomIdx()
        category = utils_data.BOND_TYPES[bond.GetBondType()]
        records.append(((begin_num, end_num, category), dist_matrix[begin_idx, end_idx]))
    return records