Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from scipy.spatial.distance import cdist | |
| from scipy.stats import gaussian_kde | |
class cKDE:
    """Weighted Gaussian KDE over a subset of semantic dimensions.

    Reference points ``Z`` (one row per sample) are weighted by a Gaussian
    kernel applied to their distance from a query point, measured only along
    the *conditioning* dimensions. A ``scipy.stats.gaussian_kde`` is then
    fitted on the remaining ("free") dimensions using those weights. Points
    drawn from it are mapped to rows of ``H`` via a nearest-neighbour lookup
    in ``Z``.

    Parameters
    ----------
    embedding : array-like
        Stored as ``H``; indexed with row indices of ``semantics``, so its
        first axis must line up with the rows of ``semantics``.
    semantics : ndarray of shape (n_samples, n_dims)
        Stored as ``Z``; the reference points.
    metric : str
        Distance metric forwarded to ``scipy.spatial.distance.cdist``.
    scale_method : {"constant", "quantile", "neff"}
        How the kernel bandwidth is chosen for each query (see ``kde``).
    scale : float
        Meaning depends on ``scale_method``: the bandwidth itself
        ("constant"), the quantile level handed to ``np.quantile``
        ("quantile"), or the target effective sample size ("neff").
    """

    def __init__(
        self, embedding, semantics, metric="euclidean", scale_method="neff", scale=2000
    ):
        self.metric = metric
        self.scale_method = scale_method
        self.scale = scale
        self.H = embedding
        self.Z = semantics

    @staticmethod
    def _free_idx(n_dims, cond_idx):
        """Return, in ascending order, the dimensions not listed in ``cond_idx``."""
        conditioned = set(cond_idx)
        # A list comprehension (rather than a set difference) guarantees a
        # deterministic, ascending column order for the fitted KDE.
        return [i for i in range(n_dims) if i not in conditioned]

    def _quantile_scale(self, Z_cond_dist):
        """Return the ``self.scale`` quantile of the conditioning distances."""
        return np.quantile(Z_cond_dist, self.scale)

    def _neff_scale(self, Z_cond_dist):
        """Pick the bandwidth whose effective sample size is closest to ``self.scale``.

        A fixed grid of 100 candidate bandwidths in [0.01, 0.4] is searched.
        For each candidate the Kish effective sample size
        ``(sum w)^2 / sum(w^2)`` of the Gaussian weights is computed.
        """
        scales = np.linspace(1e-02, 0.4, 100)[:, None]
        sq_dist = np.ravel(Z_cond_dist) ** 2
        # The effective sample size is invariant to a common rescaling of the
        # weights, so shifting by the smallest squared distance is exact and
        # keeps at least one weight at 1. Without it, small bandwidths
        # underflow every weight to 0 and produce NaN (0/0).
        sq_dist = sq_dist - sq_dist.min()
        weights = np.exp(-sq_dist[None, :] / (2 * scales**2))
        neff = np.sum(weights, axis=1) ** 2 / np.sum(weights**2, axis=1)
        scale_idx = np.argmin(np.abs(neff - self.scale))
        return scales[scale_idx].item()

    def _resolve_scale(self, Z_cond_dist):
        """Return the kernel bandwidth according to ``self.scale_method``.

        Raises
        ------
        ValueError
            If ``self.scale_method`` is not one of the supported names.
        """
        if self.scale_method == "constant":
            return self.scale
        if self.scale_method == "quantile":
            return self._quantile_scale(Z_cond_dist)
        if self.scale_method == "neff":
            return self._neff_scale(Z_cond_dist)
        raise ValueError(
            f"unknown scale_method {self.scale_method!r}; "
            "expected 'constant', 'quantile' or 'neff'"
        )

    def _sample(self, z, cond_idx, m):
        """Draw ``m`` points that keep ``z[cond_idx]`` and resample the rest.

        Returns an array of shape ``(m, len(z))``. If every dimension is
        conditioned on, the result is simply ``m`` copies of ``z``.
        """
        z = np.asarray(z)
        if not np.issubdtype(z.dtype, np.floating):
            # An integer buffer would silently truncate the KDE draws.
            z = z.astype(float)
        sample_idx = self._free_idx(len(z), cond_idx)
        sample_z = np.tile(z, (m, 1))
        if sample_idx:
            kde, _ = self.kde(z, cond_idx)
            # resample() returns (n_free_dims, m); transpose to rows.
            sample_z[:, sample_idx] = kde.resample(m).T
        return sample_z

    def kde(self, z, cond_idx):
        """Fit the weighted KDE of the free dimensions given ``z[cond_idx]``.

        Parameters
        ----------
        z : ndarray of shape (n_dims,)
            Query point.
        cond_idx : sequence of int
            Dimensions to condition on.

        Returns
        -------
        (gaussian_kde, float)
            The fitted density over the free dimensions (ascending order)
            and the bandwidth used for the weights.

        Raises
        ------
        ValueError
            If ``self.scale_method`` is not recognised.
        """
        sample_idx = self._free_idx(len(z), cond_idx)
        Z_sample = self.Z[:, sample_idx]
        Z_cond = self.Z[:, cond_idx]
        z_cond = z[cond_idx]
        # ravel() keeps a 1-D result even when Z has a single row.
        Z_cond_dist = cdist(z_cond.reshape(1, -1), Z_cond, self.metric).ravel()
        scale = self._resolve_scale(Z_cond_dist)
        sq_dist = Z_cond_dist**2
        # gaussian_kde normalises the weights, so shifting the exponent by
        # the minimum is exact and prevents an all-zero weight vector.
        weights = np.exp(-(sq_dist - sq_dist.min()) / (2 * scale**2))
        return gaussian_kde(Z_sample.T, weights=weights), scale

    def nearest_neighbor(self, z):
        """Return, for each row of ``z``, the index of the closest row of ``Z``.

        A 1-D ``z`` is treated as a single query row.
        """
        dist = cdist(np.atleast_2d(z), self.Z, metric=self.metric)
        return np.argmin(dist, axis=-1)

    def sample(self, z, cond_idx, m=1):
        """Sample ``m`` points per query and look up their nearest embeddings.

        Parameters
        ----------
        z : ndarray of shape (n_dims,) or (n_queries, n_dims)
            Query point(s).
        cond_idx : sequence of int
            Dimensions held fixed at the query's values.
        m : int
            Number of draws per query.

        Returns
        -------
        (ndarray, array-like)
            ``sample_z`` of shape ``(n_queries * m, n_dims)`` and the rows of
            ``H`` selected by each sample's nearest neighbour in ``Z``.
        """
        if z.ndim == 1:
            z = z.reshape(1, -1)
        sample_z = np.concatenate([self._sample(_z, cond_idx, m) for _z in z], axis=0)
        nn_idx = self.nearest_neighbor(sample_z)
        sample_h = self.H[nn_idx]
        return sample_z, sample_h