| | import copy |
| | import math |
| | import pickle |
| |
|
| | import scipy |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from scipy.linalg import sqrtm |
| | import re |
| |
|
| |
|
| |
|
| |
|
| |
|
| | def upgrade_state_dict(state_dict, prefixes=["encoder.sentence_encoder.", "encoder."]): |
| | """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'.""" |
| | pattern = re.compile("^" + "|".join(prefixes)) |
| | state_dict = {pattern.sub("", name): param for name, param in state_dict.items()} |
| | return state_dict |
| |
|
| | def map_t_to_alpha(t, alpha_scale): |
| | """ |
| | Maps t in [0,1) to the range of alphas using the inverse CDF of an exponential distribution. |
| | |
| | Args: |
| | t (torch.Tensor): A tensor of values in [0,1). |
| | alpha_scale (float): The scaling factor used in the original alpha calculation. |
| | |
| | Returns: |
| | torch.Tensor: The corresponding alpha values. |
| | """ |
| | if torch.any(t >= 1) or torch.any(t < 0): |
| | raise ValueError("t must be in the range [0,1).") |
| | |
| | return 1 + (-torch.log(1 - t)) * alpha_scale |
| | |
| | |
| |
|
| | def load_flybrain_designed_seqs(path): |
| | order = {'A': 0, 'C':1, 'G':2, 'T':3} |
| | f = open(path, "rb") |
| | data = pickle.load(f) |
| | arrays = [] |
| | for seq in data['seq']: |
| | arrays.append([order[char] for char in seq]) |
| | return torch.tensor(arrays, dtype=torch.long) |
| |
|
| |
|
| | def update_ema(current_dict, prev_ema, gamma = 0.9): |
| | ema = copy.deepcopy(prev_ema) |
| | current_dict = copy.deepcopy(current_dict) |
| | for key, current_value in current_dict.items(): |
| | ema_key = 'ema_' + key |
| | if not np.isnan(current_value): |
| | if ema_key in prev_ema: |
| | ema[ema_key] = (1 - gamma) * current_value + gamma * prev_ema[ema_key] |
| | else: |
| | ema[ema_key] = current_value |
| | return ema |
| |
|
| | def min_max_str(x): |
| | return f'min {x.min()} max {x.max()}' |
| |
|
| | def get_wasserstein_dist(embeds1, embeds2): |
| | if np.isnan(embeds2).any() or np.isnan(embeds1).any() or len(embeds1) == 0 or len(embeds2) == 0: |
| | return float('nan') |
| | mu1, sigma1 = embeds1.mean(axis=0), np.cov(embeds1, rowvar=False) |
| | mu2, sigma2 = embeds2.mean(axis=0), np.cov(embeds2, rowvar=False) |
| | ssdiff = np.sum((mu1 - mu2) ** 2.0) |
| | covmean = sqrtm(sigma1.dot(sigma2)) |
| | if np.iscomplexobj(covmean): |
| | covmean = covmean.real |
| | dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) |
| | return dist |
| |
|
| | def simplex_proj(seq): |
| | """Algorithm from https://arxiv.org/abs/1309.1541 Weiran Wang, Miguel Á. Carreira-Perpiñán""" |
| | Y = seq.reshape(-1, seq.shape[-1]) |
| | N, K = Y.shape |
| | X, _ = torch.sort(Y, dim=-1, descending=True) |
| | X_cumsum = torch.cumsum(X, dim=-1) - 1 |
| | div_seq = torch.arange(1, K + 1, dtype=Y.dtype, device=Y.device) |
| | Xtmp = X_cumsum / div_seq.unsqueeze(0) |
| |
|
| | greater_than_Xtmp = (X > Xtmp).sum(dim=1, keepdim=True) |
| | row_indices = torch.arange(N, dtype=torch.long, device=Y.device).unsqueeze(1) |
| | selected_Xtmp = Xtmp[row_indices, greater_than_Xtmp - 1] |
| |
|
| | X = torch.max(Y - selected_Xtmp, torch.zeros_like(Y)) |
| | return X.view(seq.shape) |
| |
|
| |
|
| |
|
| | def batch_project_simplex(v): |
| | u, _ = torch.sort(v, dim=1, descending=True) |
| | cssv = u.cumsum(dim=1) |
| | k = torch.arange(1, v.shape[1] + 1, device=v.device) |
| | rho = ((u * k) > (cssv - 1)).int().cumsum(dim=1).argmax(dim=1) |
| | theta = (cssv[torch.arange(v.shape[0]), rho] - 1) / (rho + 1).float() |
| | w = torch.maximum(v - theta.unsqueeze(1), torch.tensor(0.0, device=v.device)) |
| | return w |
| |
|
| | if __name__ == "__main__": |
| | a = torch.softmax(torch.rand((5,4)), dim=-1) |
| | b = torch.rand((5,4)) - 1 |
| | ab = torch.cat([a,b]) |
| | ab_proj1 = batch_project_simplex(ab) |
| | ab_proj2 = simplex_proj(ab) |
| | print('ab_proj1 - ab_proj2',ab_proj1 - ab_proj2) |
| | print('ab_proj1 - ab', ab_proj1 - ab) |
| | print('ab_proj2.sum(-1)', ab_proj2.sum(-1)) |
| | print('ab_proj2', ab_proj2) |
| |
|
| | def sample_cond_prob_path(args, seq, alphabet_size): |
| | B, L = seq.shape |
| | seq_one_hot = torch.nn.functional.one_hot(seq, num_classes=alphabet_size) |
| | if args.mode == 'dirichlet': |
| | alphas = torch.from_numpy(1 + scipy.stats.expon().rvs(size=B) * args.alpha_scale).to(seq.device).float() |
| | if args.fix_alpha: |
| | alphas = torch.ones(B, device=seq.device) * args.fix_alpha |
| | alphas_ = torch.ones(B, L, alphabet_size, device=seq.device) |
| | alphas_ = alphas_ + seq_one_hot * (alphas[:,None,None] - 1) |
| | xt = torch.distributions.Dirichlet(alphas_).sample() |
| | elif args.mode == 'distill': |
| | alphas = torch.zeros(B, device=seq.device) |
| | xt = torch.distributions.Dirichlet(torch.ones(B, L, alphabet_size, device=seq.device)).sample() |
| | elif args.mode == 'riemannian': |
| | t = torch.rand(B, device=seq.device) |
| | dirichlet = torch.distributions.Dirichlet(torch.ones(alphabet_size, device=seq.device)) |
| | x0 = dirichlet.sample((B,L)) |
| | x1 = seq_one_hot |
| | xt = t[:,None,None] * x1 + (1 - t[:,None,None]) * x0 |
| | alphas = t |
| | elif args.mode == 'ardm' or args.mode == 'lrar': |
| | mask_prob = torch.rand(1, device=seq.device) |
| | mask = torch.rand(seq.shape, device=seq.device) < mask_prob |
| | if args.mode == 'lrar': mask = ~(torch.arange(L, device=seq.device) < (1-mask_prob) * L) |
| | xt = torch.where(mask, alphabet_size, seq) |
| | xt = torch.nn.functional.one_hot(xt, num_classes=alphabet_size + 1).float() |
| | alphas = mask_prob.expand(B) |
| | return xt, alphas |
| |
|
| | def expand_simplex(xt, alphas, prior_pseudocount): |
| | prior_weights = (prior_pseudocount / (alphas + prior_pseudocount - 1))[:, None, None] |
| | return torch.cat([xt * (1 - prior_weights), xt * prior_weights], -1), prior_weights |
| |
|
| |
|
| | class DirichletConditionalFlow: |
| | def __init__(self, K=20, alpha_min=1, alpha_max=100, alpha_spacing=0.01): |
| | self.alphas = np.arange(alpha_min, alpha_max + alpha_spacing, alpha_spacing) |
| | self.beta_cdfs = [] |
| | self.bs = np.linspace(0, 1, 1000) |
| | for alph in self.alphas: |
| | self.beta_cdfs.append(scipy.special.betainc(alph, K-1, self.bs)) |
| | self.beta_cdfs = np.array(self.beta_cdfs) |
| | self.beta_cdfs_derivative = np.diff(self.beta_cdfs, axis=0) / alpha_spacing |
| | self.K = K |
| |
|
| | def c_factor(self, bs, alpha): |
| | out1 = scipy.special.beta(alpha, self.K - 1) |
| | out2 = np.where(bs < 1, out1 / ((1 - bs) ** (self.K - 1)), 0) |
| | out = np.where((bs ** (alpha - 1)) > 0, out2 / (bs ** (alpha - 1)), 0) |
| | I_func = self.beta_cdfs_derivative[np.argmin(np.abs(alpha - self.alphas))] |
| | interp = -np.interp(bs, self.bs, I_func) |
| | final = interp * out |
| | return final |
| |
|
| |
|
| | class GaussianSmearing(torch.nn.Module): |
| | |
| | def __init__(self, start=0.0, stop=5.0, embedding_dim=50): |
| | super().__init__() |
| | offset = torch.linspace(start, stop, embedding_dim) |
| | self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2 |
| | self.register_buffer("offset", offset) |
| | self.embedding_dim = embedding_dim |
| |
|
| | def forward(self, signal): |
| | shape = signal.shape |
| | signal = signal.view(-1, 1) - self.offset.view(1, -1) + 1E-6 |
| | encoded = torch.exp(self.coeff * torch.pow(signal, 2)) |
| | return encoded.view(*shape, self.embedding_dim) |
| |
|
| |
|
| | class MonotonicFunction(torch.nn.Module): |
| | def __init__(self, init_max, num_bins): |
| | super().__init__() |
| | self.w = torch.nn.Parameter(torch.ones(num_bins) * np.log(init_max) - np.log(num_bins)) |
| | self.num_bins = num_bins |
| |
|
| | def forward(self, t): |
| | widths = torch.exp(self.w) |
| | right = torch.cumsum(widths, 0) |
| | left = right - widths |
| |
|
| | bin_idx = (t * self.num_bins).long() |
| | frac_part = t - bin_idx * (1 / self.num_bins) |
| |
|
| | return left[bin_idx] + (frac_part * self.num_bins) * (right[bin_idx] - left[bin_idx]) |
| |
|
| | def invert(self, f): |
| | widths = torch.exp(self.w) |
| | left = torch.cumsum(widths, 0) - widths |
| | bin_idx = (f.unsqueeze(-1) > left).sum(-1) - 1 |
| | frac_part = f - left[bin_idx] |
| | return bin_idx / self.num_bins + frac_part / widths[bin_idx] / self.num_bins |
| |
|
| | def derivative(self, t): |
| | widths = torch.exp(self.w) |
| | right = torch.cumsum(widths, 0) |
| | left = right - widths |
| | bin_idx = (t * self.num_bins).long() |
| | return (right[bin_idx] - left[bin_idx]) * self.num_bins |
| |
|
| | class SinusoidalEmbedding(nn.Module): |
| | """ from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py """ |
| | def __init__(self, embedding_dim, embedding_scale, max_positions=10000): |
| | super().__init__() |
| | self.embedding_dim = embedding_dim |
| | self.max_positions = max_positions |
| | self.embedding_scale = embedding_scale |
| |
|
| | def forward(self, signal): |
| | shape = signal.shape |
| | signal = signal.view(-1) * self.embedding_scale |
| | half_dim = self.embedding_dim // 2 |
| | emb = math.log(self.max_positions) / (half_dim - 1) |
| | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=signal.device) * -emb) |
| | emb = signal.float()[:, None] * emb[None, :] |
| | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) |
| | if self.embedding_dim % 2 == 1: |
| | emb = F.pad(emb, (0, 1), mode='constant') |
| | assert emb.shape == (signal.shape[0], self.embedding_dim) |
| | return emb.view(*shape, self.embedding_dim ) |
| |
|
| |
|
| | class GaussianFourierProjection(nn.Module): |
| | """Gaussian Fourier embeddings for noise levels. |
| | from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32 |
| | """ |
| |
|
| | def __init__(self, embedding_dim=256, scale=1.0): |
| | super().__init__() |
| | self.W = nn.Parameter(torch.randn(embedding_dim//2) * scale, requires_grad=False) |
| | self.embedding_dim = embedding_dim |
| |
|
| | def forward(self, signal): |
| | shape = signal.shape |
| | signal = signal.view(-1) |
| | signal_proj = signal[:, None] * self.W[None, :] * 2 * np.pi |
| | emb = torch.cat([torch.sin(signal_proj), torch.cos(signal_proj)], dim=-1) |
| | return emb.view(*shape, self.embedding_dim ) |
| |
|
| | def get_signal_mapping(embedding_type, embedding_dim, embedding_scale=10000): |
| | if embedding_type == 'sinusoidal': |
| | emb_func = SinusoidalEmbedding(embedding_dim=embedding_dim, embedding_scale=embedding_scale) |
| | elif embedding_type == 'fourier': |
| | emb_func = GaussianFourierProjection(embedding_dim=embedding_dim, scale=embedding_scale) |
| | elif embedding_type == 'gaussian': |
| | emb_func = GaussianSmearing(0.0, 1, embedding_dim) |
| | else: |
| | raise NotImplemented |
| | return emb_func |
| |
|
| | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): |
| | """ |
| | Create a beta schedule that discretizes the given alpha_t_bar function, |
| | which defines the cumulative product of (1-beta) over time from t = [0,1]. |
| | |
| | :param num_diffusion_timesteps: the number of betas to produce. |
| | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and |
| | produces the cumulative product of (1-beta) up to that |
| | part of the diffusion process. |
| | :param max_beta: the maximum beta to use; use values lower than 1 to |
| | prevent singularities. |
| | """ |
| | betas = [] |
| | for i in range(num_diffusion_timesteps): |
| | t1 = i / num_diffusion_timesteps |
| | t2 = (i + 1) / num_diffusion_timesteps |
| | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) |
| | return np.array(betas) |
| |
|
| | def get_beta_schedule(num_steps): |
| |
|
| | return betas_for_alpha_bar( |
| | num_steps, |
| | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, |
| | ) |
| |
|
| |
|
| | class GaussianDiffusionSchedule: |
| | """ |
| | Utilities for training and sampling diffusion models. |
| | |
| | Ported directly from here, and then adapted over time to further experimentation. |
| | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 |
| | |
| | :param betas: a 1-D numpy array of betas for each diffusion timestep, |
| | starting at T and going to 1. |
| | :param model_mean_type: a ModelMeanType determining what the model outputs. |
| | :param model_var_type: a ModelVarType determining how variance is output. |
| | :param loss_type: a LossType determining the loss function to use. |
| | :param rescale_timesteps: if True, pass floating point timesteps into the |
| | model so that they are always scaled like in the |
| | original paper (0 to 1000). |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | timesteps, |
| | noise_scale=1.0, |
| | ): |
| | betas = get_beta_schedule(timesteps) |
| |
|
| | |
| | betas = np.array(betas, dtype=np.float64) |
| | self.betas = betas |
| | assert len(betas.shape) == 1, "betas must be 1-D" |
| | assert (betas > 0).all() and (betas <= 1).all() |
| |
|
| | self.timesteps = int(betas.shape[0]) |
| | self.noise_scale = noise_scale |
| |
|
| | alphas = 1.0 - betas |
| | self.alphas_cumprod = np.cumprod(alphas, axis=0) |
| | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) |
| | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) |
| | assert self.alphas_cumprod_prev.shape == (self.timesteps,) |
| |
|
| | |
| | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) |
| | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) |
| | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) |
| | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) |
| | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) |
| |
|
| | |
| | self.posterior_variance = ( |
| | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) |
| | ) |
| | |
| | |
| | self.posterior_log_variance_clipped = np.log( |
| | np.append(self.posterior_variance[1], self.posterior_variance[1:]) |
| | ) |
| | self.posterior_mean_coef1 = ( |
| | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) |
| | ) |
| | self.posterior_mean_coef2 = ( |
| | (1.0 - self.alphas_cumprod_prev) |
| | * np.sqrt(alphas) |
| | / (1.0 - self.alphas_cumprod) |
| | ) |
| |
|
| | def q_sample(self, x_start, t, noise=None): |
| | """ |
| | Diffuse the data for a given number of diffusion steps. |
| | |
| | In other words, sample from q(x_t | x_0). |
| | |
| | :param x_start: the initial data batch. |
| | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. |
| | :param noise: if specified, the split-out normal noise. |
| | :return: A noisy version of x_start. |
| | """ |
| | if noise is None: |
| | noise = self.noise_scale * torch.randn_like(x_start) |
| | |
| | assert noise.shape == x_start.shape |
| | return ( |
| | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start |
| | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) |
| | * noise |
| | ) |
| |
|
| | def q_posterior_mean_variance(self, x_start, x_t, t): |
| | """ |
| | Compute the mean and variance of the diffusion posterior: |
| | |
| | q(x_{t-1} | x_t, x_0) |
| | |
| | """ |
| | assert x_start.shape == x_t.shape |
| | posterior_mean = ( |
| | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start |
| | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t |
| | ) |
| | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) |
| | posterior_log_variance_clipped = _extract_into_tensor( |
| | self.posterior_log_variance_clipped, t, x_t.shape |
| | ) |
| |
|
| | posterior_variance = (self.noise_scale ** 2) * posterior_variance |
| | posterior_log_variance_clipped = 2 * np.log(self.noise_scale) + posterior_log_variance_clipped |
| |
|
| | assert ( |
| | posterior_mean.shape[0] |
| | == posterior_variance.shape[0] |
| | == posterior_log_variance_clipped.shape[0] |
| | == x_start.shape[0] |
| | ) |
| | return posterior_mean, posterior_variance, posterior_log_variance_clipped |
| |
|
| |
|
| | def _extract_into_tensor(arr, timesteps, broadcast_shape): |
| | """ |
| | Extract values from a 1-D numpy array for a batch of indices. |
| | |
| | :param arr: the 1-D numpy array. |
| | :param timesteps: a tensor of indices into the array to extract. |
| | :param broadcast_shape: a larger shape of K dimensions with the batch |
| | dimension equal to the length of timesteps. |
| | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. |
| | """ |
| | res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() |
| | while len(res.shape) < len(broadcast_shape): |
| | res = res[..., None] |
| | return res.expand(broadcast_shape) |
| |
|
| |
|
| | def space_timesteps(num_timesteps, section_counts): |
| | """ |
| | Create a list of timesteps to use from an original diffusion process, |
| | given the number of timesteps we want to take from equally-sized portions |
| | of the original process. |
| | |
| | For example, if there's 300 timesteps and the section counts are [10,15,20] |
| | then the first 100 timesteps are strided to be 10 timesteps, the second 100 |
| | are strided to be 15 timesteps, and the final 100 are strided to be 20. |
| | |
| | If the stride is a string starting with "ddim", then the fixed striding |
| | from the DDIM paper is used, and only one section is allowed. |
| | |
| | :param num_timesteps: the number of diffusion steps in the original |
| | process to divide up. |
| | :param section_counts: either a list of numbers, or a string containing |
| | comma-separated numbers, indicating the step count |
| | per section. As a special case, use "ddimN" where N |
| | is a number of steps to use the striding from the |
| | DDIM paper. |
| | :return: a set of diffusion steps from the original process to use. |
| | """ |
| | if isinstance(section_counts, str): |
| | if section_counts.startswith("ddim"): |
| | desired_count = int(section_counts[len("ddim"):]) |
| | for i in range(1, num_timesteps): |
| | if len(range(0, num_timesteps, i)) == desired_count: |
| | return set(range(0, num_timesteps, i)) |
| | raise ValueError( |
| | f"cannot create exactly {num_timesteps} steps with an integer stride" |
| | ) |
| | section_counts = [int(x) for x in section_counts.split(",")] |
| | size_per = num_timesteps // len(section_counts) |
| | extra = num_timesteps % len(section_counts) |
| | start_idx = 0 |
| | all_steps = [] |
| | for i, section_count in enumerate(section_counts): |
| | size = size_per + (1 if i < extra else 0) |
| | if size < section_count: |
| | raise ValueError( |
| | f"cannot divide section of {size} steps into {section_count}" |
| | ) |
| | if section_count <= 1: |
| | frac_stride = 1 |
| | else: |
| | frac_stride = (size - 1) / (section_count - 1) |
| | cur_idx = 0.0 |
| | taken_steps = [] |
| | for _ in range(section_count): |
| | taken_steps.append(start_idx + round(cur_idx)) |
| | cur_idx += frac_stride |
| | all_steps += taken_steps |
| | start_idx += size |
| | return set(all_steps) |
| |
|
| | def timestep_embedding(timesteps, dim, max_period=10000): |
| | """ |
| | Create sinusoidal timestep embeddings. |
| | |
| | :param timesteps: a 1-D Tensor of N indices, one per batch element. |
| | These may be fractional. |
| | :param dim: the dimension of the output. |
| | :param max_period: controls the minimum frequency of the embeddings. |
| | :return: an [N x dim] Tensor of positional embeddings. |
| | """ |
| | half = dim // 2 |
| | freqs = torch.exp( |
| | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half |
| | ).to(device=timesteps.device) |
| | args = timesteps[:, None].float() * freqs[None] |
| | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) |
| | if dim % 2: |
| | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) |
| | return embedding |