Transcendental-Programmer
Refactor core logic: move and modularize all latent space, sampling, and utility code into faceforge_core/
e3af1ef
| from abc import abstractmethod | |
| import torch | |
| import numpy as np | |
| from .utils import recursive_find_device, recursive_find_dtype | |
class EncodingSampler:
    """
    Class to sample encodings given low dimensional spatial relationships.

    Subclasses implement ``__call__``. It turns a query point and the points
    of the existing prompts into one coefficient per prompt, then delegates
    to :meth:`apply_coefs` to form the weighted sum of the stored encodings.

    :param encodes: Either a single tensor whose first axis indexes the
        prompts, or a list/tuple of such tensors. ``None`` entries are
        allowed and are passed through as ``None``.
    """

    def __init__(self, encodes):
        self.encodes = encodes

    @staticmethod
    def _as_coef_tensor(coefs, dtype, device):
        """Convert ``coefs`` (ndarray, sequence or tensor) to a tensor of ``dtype`` on ``device``."""
        # NOTE: Convert from float64 first to `dtype` and *then* to `device` to
        # prevent issues with certain devices not supporting f64
        # (*cough cough* Apple)
        if torch.is_tensor(coefs):
            return coefs.to(dtype).to(device)
        # `torch.from_numpy` rejects plain Python sequences and arrays with
        # negative strides (e.g. `coefs[::-1]`), so normalize to a contiguous
        # ndarray first. The dtype of an existing ndarray is preserved.
        coefs = np.ascontiguousarray(coefs)
        return torch.from_numpy(coefs).to(dtype).to(device)

    @staticmethod
    def _weighted_sum(coefs, encodes):
        """
        Sum ``encodes`` along its first axis, weighting entry ``i`` by ``coefs[i]``.

        :param coefs: 1-D tensor with one coefficient per entry of ``encodes``.
        :param encodes: Tensor of shape ``[N, ...]`` (any rank >= 1), or ``None``.
        :return: Tensor of shape ``encodes.shape[1:]``, or ``None`` if
            ``encodes`` is ``None``.
        :raises ValueError: If ``encodes`` is 0-dimensional, or if the number
            of coefficients does not match ``encodes.shape[0]``.
        """
        if encodes is None:
            return None
        ndim = len(encodes.shape)
        if ndim == 0:
            raise ValueError(
                "Encoding Sampler couldn't figure out shape of encodings"
                f" (got shape {tuple(encodes.shape)})"
            )
        # Without this check, a size-1 axis on either side would broadcast
        # silently and produce a wrong (but well-shaped) result.
        if coefs.ndim != 1 or coefs.shape[0] != encodes.shape[0]:
            raise ValueError(
                f"Expected one coefficient per encoding ({encodes.shape[0]}), "
                f"got coefficients of shape {tuple(coefs.shape)}"
            )
        # Reshape to [N, 1, ..., 1] so the coefficients broadcast over every
        # trailing axis (same as coefs[:, None] / coefs[:, None, None]).
        weights = coefs.reshape(-1, *([1] * (ndim - 1)))
        return (weights * encodes).sum(0)

    def apply_coefs(self, coefs):
        """
        Linear combination of encodings given coefs.

        :param coefs: One coefficient per encoding, i.e. per entry along the
            first axis of each tensor in ``self.encodes``. May be a numpy
            array, a sequence of numbers, or a tensor.
        :return: The weighted sum with the first axis reduced away. If
            ``self.encodes`` is a list or tuple, returns a list with one
            result per element (``None`` elements stay ``None``). Otherwise
            returns a single result.
        """
        device = recursive_find_device(self.encodes)
        dtype = recursive_find_dtype(self.encodes)
        coefs = self._as_coef_tensor(coefs, dtype, device)
        if isinstance(self.encodes, (list, tuple)):
            return [self._weighted_sum(coefs, enc) for enc in self.encodes]
        return self._weighted_sum(coefs, self.encodes)

    @abstractmethod
    def __call__(self, point, other_points):
        """
        Compute the sampled encoding for ``point``; implemented by subclasses.

        :param point: Point in low space representing user input ([2,] array)
        :param other_points: Points in low space representing existing prompts ([N,2] array)
        """
        # `abstractmethod` documents the contract only. This class does not
        # use ABCMeta, so it is not enforced, and this base version returns None.
        return None
class DistanceSampling(EncodingSampler):
    """
    Weight each existing encoding by how close its point is to the query point.

    The coefficient for prompt point ``q_i`` is ``1 / (1 + ||point - q_i||^2)``.
    A prompt sitting exactly at ``point`` gets weight 1, and weights decay
    towards 0 with distance. The coefficients are not normalized to sum to 1.
    """

    def __call__(self, point, other_points):
        # Broadcast the single [2] query point against every row of other_points.
        offsets = point[None, :] - other_points
        squared_dist = np.linalg.norm(offsets, axis=1) ** 2
        weights = 1. / (1. + squared_dist)
        return self.apply_coefs(weights)
class CircleSampling(EncodingSampler):
    """
    Sampler that views all encodings as points on a unit circle.

    Each encoding is weighted by the dot product between the query point and
    that encoding's point. When every point has unit length, this equals the
    cosine similarity: 1 for the same direction and -1 for the opposite one.
    """

    def __call__(self, point, other_points):
        # Design intent: same direction -> weight 1, opposite -> weight 0.
        # As implemented, opposite directions get negative weights (-1 for
        # unit vectors), and these are handed to apply_coefs unchanged.
        # A variant that was considered replaced only the negative values
        # with np.tanh(value). tanh(x) is roughly x near 0 and saturates
        # towards -1, as a possible analogue for "negative signals".
        sims = point @ other_points.T  # [2] @ [2, N] -> [N]
        return self.apply_coefs(sims)