Spaces:
Sleeping
Sleeping
File size: 5,139 Bytes
dcacefd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
"""Utils for evaluating bond length."""
import collections
from typing import Tuple, Sequence, Dict, Optional
import numpy as np
from scipy import spatial as sci_spatial
import matplotlib.pyplot as plt
from utils.evaluation import eval_bond_length_config
import utils.data as utils_data
# Key identifying a kind of bond: the two atomic numbers plus an integer bond-type code.
BondType = Tuple[int, int, int]  # (atomic_num, atomic_num, bond_type)
# A single observed bond: its key together with its length.
BondLengthData = Tuple[BondType, float]  # (bond_type, bond_length)
# Per-bond-key empirical distribution (the arrays built by get_distribution).
BondLengthProfile = Dict[BondType, np.ndarray]  # bond_type -> empirical distribution
def get_distribution(distances: Sequence[float], bins=eval_bond_length_config.DISTANCE_BINS) -> np.ndarray:
    """Histogram ``distances`` over ``bins`` and normalise to relative frequencies.

    Every distance is assigned a slot with ``np.searchsorted(bins, distances)``:
    slot 0 collects values at or below ``bins[0]``, slot ``i`` collects values in
    ``(bins[i - 1], bins[i]]`` and the last slot collects values above ``bins[-1]``.

    Args:
        distances: One-dimensional sequence of distances.
        bins: Bin edges, expected in ascending order as ``np.searchsorted``
            requires. Defaults to ``eval_bond_length_config.DISTANCE_BINS``.

    Returns:
        np.ndarray: ``len(bins) + 1`` relative frequencies. If ``distances`` is
        empty the total count is zero, so every entry is NaN.
    """
    slot_indices = np.searchsorted(bins, distances)
    # minlength keeps one entry per possible slot, including slots that received no value.
    bin_counts = np.bincount(slot_indices, minlength=len(bins) + 1)
    return bin_counts / np.sum(bin_counts)
def _format_bond_type(bond_type: BondType) -> BondType:
    """Return ``bond_type`` with its two atom entries in ascending order.

    The third entry (the bond category) is passed through unchanged, so a bond
    listed in either direction maps to the same key.
    """
    first, second, category = bond_type
    low, high = (second, first) if first > second else (first, second)
    return low, high, category
def get_bond_length_profile(bond_lengths: Sequence[BondLengthData]) -> BondLengthProfile:
    """Group bond lengths by canonical bond type and histogram each group.

    Args:
        bond_lengths: ``(bond_type, bond_length)`` pairs.

    Returns:
        Mapping from the bond type, as normalised by ``_format_bond_type``, to
        the result of ``get_distribution`` for the lengths seen under that type.
    """
    lengths_by_type = {}
    for raw_type, length in bond_lengths:
        canonical_type = _format_bond_type(raw_type)
        lengths_by_type.setdefault(canonical_type, []).append(length)

    profile = {}
    for canonical_type, lengths in lengths_by_type.items():
        profile[canonical_type] = get_distribution(lengths)
    return profile
def _bond_type_str(bond_type: BondType) -> str:
    """Render a bond type as ``'<atom1>-<atom2>|<bond_category>'``."""
    first, second, category = bond_type
    return f'{first}-{second}|{category}'
def eval_bond_length_profile(bond_length_profile: BondLengthProfile) -> Dict[str, Optional[float]]:
    """Score a bond length profile against the reference distributions.

    Args:
        bond_length_profile: Mapping from bond type to its distribution.

    Returns:
        One ``'JSD_<bond type string>'`` entry per bond type found in
        ``eval_bond_length_config.EMPIRICAL_DISTRIBUTIONS``. The value is the
        Jensen-Shannon distance between the reference and the given
        distribution, or ``None`` when the profile lacks that bond type.
    """
    reference = eval_bond_length_config.EMPIRICAL_DISTRIBUTIONS
    metrics = {}
    for bond_type, gt_distribution in reference.items():
        key = f'JSD_{_bond_type_str(bond_type)}'
        if bond_type in bond_length_profile:
            # Jensen-Shannon distance between reference and observed distribution
            metrics[key] = sci_spatial.distance.jensenshannon(gt_distribution, bond_length_profile[bond_type])
        else:
            metrics[key] = None
    return metrics
def get_pair_length_profile(pair_lengths):
    """Build the 'CC_2A' and 'All_12A' pair-distance distributions.

    ``pair_lengths`` is traversed exactly once, so generators and other
    one-shot iterables are supported as well as lists.

    Args:
        pair_lengths: Iterable of entries whose item 0 is the pair key and whose
            item 1 is the distance, e.g. ``((6, 6), 1.5)``.

    Returns:
        dict: ``'CC_2A'`` holds the distribution of distances below 2 for
        entries keyed ``(6, 6)``, binned over ``np.linspace(0, 2, 100)``;
        ``'All_12A'`` holds the distribution of every distance below 12, binned
        over ``np.linspace(0, 12, 100)``.
    """
    cc_dist = []
    all_dist = []
    for entry in pair_lengths:
        pair_key, distance = entry[0], entry[1]
        if distance < 12:
            all_dist.append(distance)
        # (6, 6) is presumably a carbon-carbon pair, matching the 'CC' in the key -- confirm.
        if pair_key == (6, 6) and distance < 2:
            cc_dist.append(distance)
    return {
        'CC_2A': get_distribution(cc_dist, bins=np.linspace(0, 2, 100)),
        'All_12A': get_distribution(all_dist, bins=np.linspace(0, 12, 100)),
    }
def eval_pair_length_profile(pair_length_profile):
    """Score a pair-distance profile against the reference distributions.

    Args:
        pair_length_profile: Mapping from profile key to its distribution.

    Returns:
        dict: One ``'JSD_<key>'`` entry per key of
        ``eval_bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS``. The value is
        the Jensen-Shannon distance to the reference, or ``None`` when the
        profile has no entry for that key.
    """
    reference = eval_bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS
    metrics = {}
    for k, gt_distribution in reference.items():
        metrics[f'JSD_{k}'] = (
            sci_spatial.distance.jensenshannon(gt_distribution, pair_length_profile[k])
            if k in pair_length_profile
            else None
        )
    return metrics
def plot_distance_hist(pair_length_profile, metrics=None, save_path=None):
    """Plot reference and given pair-distance distributions, one panel per key.

    Args:
        pair_length_profile: Mapping with a distribution for every key of
            ``eval_bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS``.
        metrics: Optional mapping with a ``'JSD_<key>'`` number per key; when
            given, the value is shown in the panel title.
        save_path: When given, the figure is written there; otherwise it is
            displayed. The figure is closed afterwards in both cases.
    """
    reference = eval_bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS
    n_panels = len(reference)
    plt.figure(figsize=(6 * n_panels, 4))

    for panel, (k, gt_distribution) in enumerate(reference.items(), start=1):
        plt.subplot(1, n_panels, panel)
        x = eval_bond_length_config.PAIR_EMPIRICAL_BINS[k]
        # The first slot of each distribution is dropped before plotting against x;
        # presumably this aligns the values with the bin edges -- confirm.
        plt.step(x, gt_distribution[1:])
        plt.step(x, pair_length_profile[k][1:])
        plt.legend(['True', 'Learned'])
        title = k if metrics is None else f'{k} JS div: {metrics["JSD_" + k]:.4f}'
        plt.title(title)

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
    plt.close()
def pair_distance_from_pos_v(pos, elements):
    """List the Euclidean distance of every unordered pair of positions.

    Args:
        pos: Array of coordinates with one row per atom.
        elements: Per-atom labels, indexable in the same order as ``pos``.

    Returns:
        list: ``((elements[i], elements[j]), distance)`` for every ``i < j``,
        ordered by ``i`` and then by ``j``.
    """
    offsets = pos[None, :] - pos[:, None]
    dist_matrix = np.sqrt(np.sum(offsets ** 2, axis=-1))
    # Strict upper triangle in row-major order: each pair once, i before j.
    first_idx, second_idx = np.triu_indices(len(pos), k=1)
    return [
        ((elements[i], elements[j]), dist_matrix[i, j])
        for i, j in zip(first_idx.tolist(), second_idx.tolist())
    ]
def bond_distance_from_mol(mol):
    """Collect the length of every bond in ``mol``.

    Args:
        mol: Molecule object providing ``GetConformer().GetPositions()`` and
            ``GetBonds()`` (presumably an RDKit molecule -- confirm).

    Returns:
        list: One ``((begin_atomic_num, end_atomic_num, bond_type), distance)``
        entry per bond, where ``bond_type`` is
        ``utils_data.BOND_TYPES[bond.GetBondType()]`` and ``distance`` is the
        Euclidean distance between the bond's two atom positions.
    """
    pos = mol.GetConformer().GetPositions()
    offsets = pos[None, :] - pos[:, None]
    dist_matrix = np.sqrt(np.sum(offsets ** 2, axis=-1))

    def _describe(bond):
        # Key mirrors BondType: both atomic numbers plus the mapped bond type.
        key = (
            bond.GetBeginAtom().GetAtomicNum(),
            bond.GetEndAtom().GetAtomicNum(),
            utils_data.BOND_TYPES[bond.GetBondType()],
        )
        return key, dist_matrix[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]

    return [_describe(bond) for bond in mol.GetBonds()]
|