Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from einops import reduce | |
| from tqdm.auto import tqdm | |
| from functools import partial | |
| from .transformer import Transformer | |
| from ..model_utils import default, identity, extract | |
| from .control import * | |
def linear_beta_schedule(timesteps):
    """Build a linearly spaced beta (noise variance) schedule.

    The endpoints 1e-4 and 2e-2 are both multiplied by ``1000 / timesteps``
    before interpolation.

    Args:
        timesteps: Number of diffusion steps, i.e. the length of the schedule.

    Returns:
        A 1-D ``torch.float64`` tensor holding ``timesteps`` betas.
    """
    rescale = 1000 / timesteps
    first, last = rescale * 0.0001, rescale * 0.02
    return torch.linspace(first, last, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
    """Build the cosine beta schedule.

    As proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Args:
        timesteps: Number of diffusion steps, i.e. the length of the schedule.
        s: Offset added to the normalised time before the cosine is applied.

    Returns:
        A 1-D ``torch.float64`` tensor of ``timesteps`` betas, clipped to
        the range ``[0, 0.999]``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    # Normalise so the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)
class Tiffusion(nn.Module):
    """Diffusion model built around a ``Transformer`` denoiser."""

    def __init__(
        self,
        seq_length,
        feature_size,
        n_layer_enc=3,
        n_layer_dec=6,
        d_model=None,
        timesteps=1000,
        sampling_timesteps=None,
        loss_type="l1",
        beta_schedule="cosine",
        n_heads=4,
        mlp_hidden_times=4,
        eta=0.0,
        attn_pd=0.0,
        resid_pd=0.0,
        kernel_size=None,
        padding_size=None,
        use_ff=True,
        reg_weight=None,
        control_signal=None,
        moving_average=False,
        **kwargs,
    ):
        """Create the denoiser and precompute all diffusion coefficients.

        Args:
            seq_length: Sequence length; forwarded to the Transformer as
                ``n_channel`` and ``max_len``.
            feature_size: Number of variables; forwarded as ``n_feat``.
            n_layer_enc: Forwarded to the Transformer.
            n_layer_dec: Forwarded to the Transformer.
            d_model: Forwarded to the Transformer as ``n_embd``.
            timesteps: Number of training diffusion steps.
            sampling_timesteps: Number of sampling steps; defaults to the
                number of training steps and must not exceed it.
            loss_type: ``"l1"`` or ``"l2"``.
            beta_schedule: ``"linear"`` or ``"cosine"``.
            n_heads: Forwarded to the Transformer.
            mlp_hidden_times: Forwarded to the Transformer.
            eta: Stored and used to scale the noise term in fast sampling.
            attn_pd: Forwarded to the Transformer as ``attn_pdrop``.
            resid_pd: Forwarded to the Transformer as ``resid_pdrop``.
            kernel_size: Forwarded (with ``padding_size``) as ``conv_params``.
            padding_size: See ``kernel_size``.
            use_ff: Whether the training loss adds the Fourier-domain term.
            reg_weight: Overrides the default ``ff_weight`` / ``sum_weight``.
            control_signal: Training control signal stored on the module.
                ``None`` (the default) stores a fresh empty dict.
            moving_average: Stored on the module as ``self.moving_average``.
            **kwargs: Extra keyword arguments forwarded to the Transformer.

        Raises:
            ValueError: If ``beta_schedule`` is not a known schedule name.
        """
        super().__init__()

        self.eta, self.use_ff = eta, use_ff
        self.seq_length = seq_length
        self.feature_size = feature_size
        self.ff_weight = default(reg_weight, math.sqrt(self.seq_length) / 5)
        self.sum_weight = default(reg_weight, math.sqrt(self.seq_length // 10) / 50)
        # A literal `{}` default would be one dict shared by every instance
        # built without an explicit control signal; allocate per instance.
        self.training_control_signal = (
            {} if control_signal is None else control_signal
        )  # training control signal
        self.moving_average = moving_average

        self.model: Transformer = Transformer(
            n_feat=feature_size,
            n_channel=seq_length,
            n_layer_enc=n_layer_enc,
            n_layer_dec=n_layer_dec,
            n_heads=n_heads,
            attn_pdrop=attn_pd,
            resid_pdrop=resid_pd,
            mlp_hidden_times=mlp_hidden_times,
            max_len=seq_length,
            n_embd=d_model,
            conv_params=[kernel_size, padding_size],
            **kwargs,
        )

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # sampling related parameters
        # default num sampling timesteps to number of timesteps at training
        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        assert self.sampling_timesteps <= timesteps
        self.fast_sampling = self.sampling_timesteps < timesteps

        def register_buffer(name, val):
            # Schedules are computed in float64; buffers are stored as float32.
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        register_buffer("posterior_variance", posterior_variance)

        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain
        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )

        # calculate reweighting
        register_buffer(
            "loss_weight",
            torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas / 100,
        )
def predict_noise_from_start(self, x_t, t, x0):
    """Recover the noise implied by a noisy sample and a clean estimate.

    Computes ``(sqrt_recip_alphas_cumprod[t] * x_t - x0) /
    sqrt_recipm1_alphas_cumprod[t]``, with the coefficients looked up via
    the project helper ``extract`` (presumably per-sample and broadcastable
    to ``x_t.shape`` — confirm against its definition).

    Args:
        x_t: Noisy sample at step ``t``.
        t: Diffusion step indices.
        x0: Estimate of the clean sample.

    Returns:
        The noise estimate.
    """
    recip = extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recip_minus_one = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return (recip * x_t - x0) / recip_minus_one
def predict_start_from_noise(self, x_t, t, noise):
    """Recover the clean-sample estimate from a noisy sample and its noise.

    Computes ``sqrt_recip_alphas_cumprod[t] * x_t -
    sqrt_recipm1_alphas_cumprod[t] * noise``, with the coefficients looked
    up via the project helper ``extract``.

    Args:
        x_t: Noisy sample at step ``t``.
        t: Diffusion step indices.
        noise: Noise estimate for ``x_t``.

    Returns:
        The estimate of the clean sample.
    """
    recip = extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recip_minus_one = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return recip * x_t - recip_minus_one * noise
def q_posterior(self, x_start, x_t, t):
    """Return the parameters of the posterior q(x_{t-1} | x_t, x_0).

    Args:
        x_start: Clean-sample (x_0) estimate.
        x_t: Noisy sample at step ``t``.
        t: Diffusion step indices.

    Returns:
        A tuple ``(mean, variance, clipped_log_variance)``; each coefficient
        is looked up from the registered buffers via ``extract``.
    """
    shape = x_t.shape
    coef_start = extract(self.posterior_mean_coef1, t, shape)
    coef_current = extract(self.posterior_mean_coef2, t, shape)
    mean = coef_start * x_start + coef_current * x_t
    variance = extract(self.posterior_variance, t, shape)
    log_variance = extract(self.posterior_log_variance_clipped, t, shape)
    return mean, variance, log_variance
def output(self, x, t, padding_masks=None, control_signal=None):
    """Run the denoiser and return the sum of its two components.

    Args:
        x: Input passed to ``self.model``.
        t: Diffusion step indices.
        padding_masks: Forwarded to ``self.model``.
        control_signal: Forwarded to ``self.model``.

    Returns:
        ``trend + season`` as produced by ``self.model``.
    """
    components = self.model(
        x, t, padding_masks=padding_masks, control_signal=control_signal
    )
    trend, season = components
    return trend + season
def model_predictions(
    self, x, t, clip_x_start=False, padding_masks=None, control_signal=None
):
    """Predict the clean sample and the matching noise for ``x`` at ``t``.

    Args:
        x: Noisy sample.
        t: Diffusion step indices.
        clip_x_start: When true, clamp the predicted clean sample to
            ``[-1, 1]`` before deriving the noise.
        padding_masks: Mask forwarded to the denoiser; defaults to an
            all-true boolean mask of shape ``(x.shape[0], self.seq_length)``.
        control_signal: Forwarded to the denoiser.

    Returns:
        A tuple ``(pred_noise, x_start)``.
    """
    if padding_masks is None:
        padding_masks = torch.ones(
            x.shape[0], self.seq_length, dtype=bool, device=x.device
        )

    if clip_x_start:
        post_process = partial(torch.clamp, min=-1.0, max=1.0)
    else:
        post_process = identity

    x_start = post_process(
        self.output(x, t, padding_masks, control_signal=control_signal)
    )
    return self.predict_noise_from_start(x, t, x_start), x_start
def p_mean_variance(self, x, t, clip_denoised=True, control_signal=None):
    """Compute the reverse-step mean and variance for ``x`` at step ``t``.

    Args:
        x: Noisy sample.
        t: Diffusion step indices.
        clip_denoised: When true, clamp the predicted clean sample to
            ``[-1, 1]`` in place before forming the posterior.
        control_signal: Forwarded to ``model_predictions``.

    Returns:
        A tuple ``(model_mean, posterior_variance, posterior_log_variance,
        x_start)``.
    """
    _pred_noise, x_start = self.model_predictions(
        x, t, control_signal=control_signal
    )
    if clip_denoised:
        x_start.clamp_(-1.0, 1.0)

    posterior = self.q_posterior(x_start=x_start, x_t=x, t=t)
    mean, variance, log_variance = posterior
    return mean, variance, log_variance, x_start
def p_sample(self, x, t: int, clip_denoised=True, control_signal=None):
    """Draw one reverse-diffusion step from ``x`` at integer step ``t``.

    Args:
        x: Current noisy sample.
        t: Diffusion step shared by the whole batch.
        clip_denoised: Forwarded to ``p_mean_variance``.
        control_signal: Forwarded to ``p_mean_variance``.

    Returns:
        A tuple ``(pred_img, x_start)``: the sample for the previous step
        and the clean-sample estimate used to compute it.
    """
    step_batch = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
    mean, _, log_variance, x_start = self.p_mean_variance(
        x=x,
        t=step_batch,
        clip_denoised=clip_denoised,
        control_signal=control_signal,
    )

    # no noise if t == 0
    if t > 0:
        noise = torch.randn_like(x)
    else:
        noise = 0.0

    std = (0.5 * log_variance).exp()
    return mean + std * noise, x_start
def sample(self, shape, control_signal=None):
    """Generate samples by running every diffusion step in reverse.

    NOTE(review): no gradient-disabling context is visible on this method;
    confirm whether callers are expected to wrap it in ``torch.no_grad()``.

    Args:
        shape: Shape of the tensor to generate.
        control_signal: Forwarded to ``p_sample`` at every step.

    Returns:
        The sample after the final (t == 0) step.
    """
    img = torch.randn(shape, device=self.betas.device)
    countdown = reversed(range(0, self.num_timesteps))
    for step in tqdm(
        countdown, desc="sampling loop time step", total=self.num_timesteps
    ):
        img, _ = self.p_sample(img, step, control_signal=control_signal)
    return img
def fast_sample(self, shape, clip_denoised=True, model_kwargs=None):
    """Generate samples using a strided subset of the diffusion steps.

    Args:
        shape: Shape of the tensor to generate; ``shape[0]`` is the batch.
        clip_denoised: Forwarded to ``model_predictions`` as ``clip_x_start``.
        model_kwargs: Optional dict; its ``"model_control_signal"`` entry
            (default ``{}``) is passed to the denoiser as the control signal.

    Returns:
        The generated sample.
    """
    batch = shape[0]
    device = self.betas.device
    eta = self.eta
    control = model_kwargs.get("model_control_signal", {}) if model_kwargs else {}

    # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    grid = torch.linspace(
        -1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1
    )
    descending = grid.int().tolist()[::-1]
    # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
    step_pairs = list(zip(descending[:-1], descending[1:]))

    img = torch.randn(shape, device=device)
    for time, time_next in tqdm(step_pairs, desc="sampling loop time step"):
        time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
        pred_noise, x_start, *_ = self.model_predictions(
            img, time_cond, clip_x_start=clip_denoised, control_signal=control
        )

        # The last pair ends at -1: return the clean-sample estimate as is.
        if time_next < 0:
            img = x_start
            continue

        a_now = self.alphas_cumprod[time]
        a_next = self.alphas_cumprod[time_next]
        sigma = eta * ((1 - a_now / a_next) * (1 - a_next) / (1 - a_now)).sqrt()
        direction_scale = (1 - a_next - sigma**2).sqrt()
        noise = torch.randn_like(img)
        img = x_start * a_next.sqrt() + direction_scale * pred_noise + sigma * noise

    return img
def generate_mts(self, batch_size=16):
    """Generate a batch of unconditional samples.

    Uses ``fast_sample`` when ``self.fast_sampling`` is set, otherwise
    ``sample``.

    Args:
        batch_size: Number of sequences to generate.

    Returns:
        The sampler output for shape
        ``(batch_size, self.seq_length, self.feature_size)``.
    """
    feature_size, seq_length = self.feature_size, self.seq_length
    if self.fast_sampling:
        sampler = self.fast_sample
    else:
        sampler = self.sample
    return sampler((batch_size, seq_length, feature_size))
def generate_mts_infill(self, target, partial_mask=None, clip_denoised=True, model_kwargs=None):
    """Run conditional (infilling) generation against ``target``.

    Always delegates to ``fast_sample_infill_float_mask`` with
    ``self.sampling_timesteps`` steps.

    Args:
        target: Target series; its shape is the shape that is generated.
        partial_mask: Mask forwarded to the sampler.
        clip_denoised: Forwarded to the sampler.
        model_kwargs: Forwarded to the sampler.

    Returns:
        Whatever ``fast_sample_infill_float_mask`` returns.
    """
    # Debug output. The mask's shape is read defensively because
    # `partial_mask` defaults to None, and `None.shape` used to raise
    # AttributeError here before the sampler was even reached.
    print("model_kwargs", model_kwargs)
    print("partial_mask", getattr(partial_mask, "shape", None))
    print("target", target.shape)
    return self.fast_sample_infill_float_mask(
        shape=target.shape,
        target=target,
        sampling_timesteps=self.sampling_timesteps,
        partial_mask=partial_mask,
        clip_denoised=clip_denoised,
        model_kwargs=model_kwargs,
    )
@property
def loss_fn(self):
    """Loss callable selected by ``self.loss_type``.

    Exposed as a property because the training loss uses it as
    ``self.loss_fn(model_out, target, reduction="none")``; with a plain
    zero-argument method that call raises ``TypeError``.

    Returns:
        ``F.l1_loss`` for ``"l1"`` or ``F.mse_loss`` for ``"l2"``.

    Raises:
        ValueError: If ``self.loss_type`` is neither ``"l1"`` nor ``"l2"``.
    """
    if self.loss_type == "l1":
        return F.l1_loss
    if self.loss_type == "l2":
        return F.mse_loss
    raise ValueError(f"invalid loss type {self.loss_type}")
def q_sample(self, x_start, t, noise=None):
    """Diffuse a clean sample forward to step ``t``.

    Computes ``sqrt_alphas_cumprod[t] * x_start +
    sqrt_one_minus_alphas_cumprod[t] * noise``.

    Args:
        x_start: Clean sample.
        t: Diffusion step indices.
        noise: Noise to mix in; standard normal noise shaped like
            ``x_start`` is supplied through ``default`` when omitted.

    Returns:
        The noised sample.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    signal_scale = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * noise
def calculate_dynamic_window(self, t: torch.Tensor) -> torch.Tensor:
    """Map diffusion steps to integer window sizes.

    Evaluates ``(5 ** (t / 500)) * 15 // 5 - 2`` element-wise and casts
    the result to ``torch.long``.

    Args:
        t: Tensor of diffusion steps.

    Returns:
        A ``torch.long`` tensor with the same shape as ``t``.
    """
    growth = 5 ** (t / 500)
    window = growth * 15 // 5 - 2
    return window.long()
def _train_loss(
    self,
    x_start,
    t,
    target=None,
    noise=None,
    padding_masks=None,
    control_signal=None,
):
    """Compute the weighted training loss for one batch.

    ``x_start`` is noised to step ``t``, the denoiser output is compared
    with ``target`` using ``self.loss_fn``, and, when ``self.use_ff`` is
    set, a Fourier-domain term scaled by ``self.ff_weight`` is added.

    NOTE(review): ``self.loss_fn`` is called with arguments directly, so it
    must resolve to the loss callable itself (e.g. be exposed as a
    property) — confirm.

    Args:
        x_start: Clean batch.
        t: Diffusion step indices, one per batch element.
        target: Regression target; defaults to ``x_start``.
        noise: Noise used for the forward diffusion; drawn when omitted.
        padding_masks: Forwarded to the denoiser.
        control_signal: Forwarded to the denoiser.

    Returns:
        A scalar tensor: the mean of the per-sample, step-weighted losses.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    target = x_start if target is None else target

    noisy = self.q_sample(x_start=x_start, t=t, noise=noise)  # noise sample
    model_out = self.output(noisy, t, padding_masks, control_signal=control_signal)

    # Moving-average smoothing of the target is currently disabled here.
    per_element = self.loss_fn(model_out, target, reduction="none")

    if self.use_ff:
        # FFT is taken over the last axis of the transposed tensors, then
        # the result is transposed back.
        spec_out = torch.fft.fft(model_out.transpose(1, 2), norm="forward")
        spec_tgt = torch.fft.fft(target.transpose(1, 2), norm="forward")
        spec_out, spec_tgt = spec_out.transpose(1, 2), spec_tgt.transpose(1, 2)
        real_term = self.loss_fn(
            torch.real(spec_out), torch.real(spec_tgt), reduction="none"
        )
        imag_term = self.loss_fn(
            torch.imag(spec_out), torch.imag(spec_tgt), reduction="none"
        )
        per_element += self.ff_weight * (real_term + imag_term)

    per_sample = reduce(per_element, "b ... -> b (...)", "mean")
    weighted = per_sample * extract(self.loss_weight, t, per_sample.shape)
    return weighted.mean()
| # fmt: off | |
def forward(self, x, **kwargs):
    """Compute the training loss at randomly drawn diffusion steps.

    Args:
        x: 3-D batch whose last dimension must equal ``self.feature_size``.
        **kwargs: Forwarded to ``_train_loss``.

    Returns:
        The value returned by ``_train_loss``.
    """
    batch, _, n_vars = x.shape
    feature_size = self.feature_size
    assert n_vars == feature_size, f'number of variable must be {feature_size}'
    # One uniformly sampled step per batch element.
    t = torch.randint(0, self.num_timesteps, (batch,), device=x.device).long()
    return self._train_loss(x_start=x, t=t, **kwargs)
def return_components(self, x, t: int):
    """Noise ``x`` to step ``t`` and return the denoiser's components.

    Args:
        x: 3-D batch whose last dimension must equal ``self.feature_size``.
        t: Diffusion step applied to the whole batch.

    Returns:
        A tuple ``(trend, season, residual, x_noised)`` where the first
        three come from ``self.model(..., return_res=True)``.
    """
    batch, _, n_vars = x.shape
    feature_size = self.feature_size
    assert n_vars == feature_size, f'number of variable must be {feature_size}'
    steps = torch.tensor([t]).repeat(batch).to(x.device)
    noised = self.q_sample(x, steps)
    trend, season, residual = self.model(noised, steps, return_res=True)
    return trend, season, residual, noised
| # fmt: on | |
def fast_sample_infill_float_mask(
    self,
    shape,
    target: torch.Tensor,  # target time series # [B, L, C]
    sampling_timesteps,
    partial_mask: torch.Tensor = None,  # float mask between 0 and 1 # [B, L, C]
    clip_denoised=True,
    model_kwargs=None,
):
    """Strided conditional sampling guided towards ``target``.

    After each strided update the sample is refined by ``langevin_fn`` and
    then blended with ``target`` as
    ``img * (1 - partial_mask) + target * partial_mask``.

    NOTE(review): the source's indentation was lost. The blend is assumed
    to run once inside the loop after ``langevin_fn`` and once more after
    the loop — confirm against the original file.

    Args:
        shape: Shape of the tensor to generate; ``shape[0]`` is the batch.
        target: Target series.
        sampling_timesteps: Number of strided sampling steps.
        partial_mask: Float mask; it is used unconditionally, so ``None``
            is not actually supported despite being the default.
        clip_denoised: Forwarded to ``model_predictions`` as ``clip_x_start``.
        model_kwargs: Dict read for ``"model_control_signal"`` and expanded
            into the ``langevin_fn`` call; must not be ``None`` once a
            step is executed.

    Returns:
        The blended sample.
    """
    batch = shape[0]
    device = self.betas.device
    eta = self.eta

    def blend(generated):
        # Mix generated values with the target according to the float mask.
        return generated * (1 - partial_mask) + target * partial_mask

    # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    grid = torch.linspace(-1, self.num_timesteps - 1, steps=sampling_timesteps + 1)
    descending = grid.int().tolist()[::-1]
    # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
    step_pairs = list(zip(descending[:-1], descending[1:]))

    # Initialize with noise
    img = torch.randn(shape, device=device)  # [B, L, C]

    for time, time_next in tqdm(
        step_pairs, desc="conditional sampling loop time step"
    ):
        time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
        pred_noise, x_start, *_ = self.model_predictions(
            img,
            time_cond,
            clip_x_start=clip_denoised,
            control_signal=model_kwargs.get("model_control_signal", {}),
        )

        # The last pair ends at -1: take the clean-sample estimate as is.
        if time_next < 0:
            img = x_start
            continue

        # Compute the predicted mean
        a_now = self.alphas_cumprod[time]
        a_next = self.alphas_cumprod[time_next]
        sigma = eta * ((1 - a_now / a_next) * (1 - a_next) / (1 - a_now)).sqrt()
        direction_scale = (1 - a_next - sigma**2).sqrt()
        noise = torch.randn_like(img)
        pred_mean = x_start * a_next.sqrt() + direction_scale * pred_noise
        img = pred_mean + sigma * noise

        # Langevin Dynamics part for additional gradient updates
        img = self.langevin_fn(
            sample=img,
            mean=pred_mean,
            sigma=sigma,
            t=time_cond,
            tgt_embs=target,
            partial_mask=partial_mask,
            enable_float_mask=True,
            **model_kwargs,
        )
        img = blend(img)

    return blend(img)
def sample_infill(
    self,
    shape,
    target,
    partial_mask=None,
    clip_denoised=True,
    model_kwargs=None,
):
    """Conditional sampling over every diffusion step, in reverse order.

    Each step is delegated to ``p_sample_infill``; after the last step the
    positions selected by ``partial_mask`` are overwritten with ``target``.

    NOTE(review): the source's indentation was lost; the final overwrite is
    assumed to sit after the loop — confirm.

    Args:
        shape: Shape of the tensor to generate.
        target: Target series supplying the known values.
        partial_mask: Mask used to index the known positions.
        clip_denoised: Forwarded to ``p_sample_infill``.
        model_kwargs: Forwarded to ``p_sample_infill``.

    Returns:
        The infilled sample.
    """
    img = torch.randn(shape, device=self.betas.device)
    countdown = reversed(range(0, self.num_timesteps))
    for step in tqdm(
        countdown,
        desc="conditional sampling loop time step",
        total=self.num_timesteps,
    ):
        img = self.p_sample_infill(
            x=img,
            t=step,
            clip_denoised=clip_denoised,
            target=target,
            partial_mask=partial_mask,
            model_kwargs=model_kwargs,
        )

    img[partial_mask] = target[partial_mask]
    return img
def p_sample_infill(
    self,
    x,
    target,
    t: int,
    partial_mask=None,
    clip_denoised=True,
    model_kwargs=None,
):
    """One conditional reverse-diffusion step with Langevin refinement.

    Args:
        x: Current noisy sample.
        target: Target series supplying the known values.
        t: Diffusion step shared by the whole batch.
        partial_mask: Mask used to index the known positions.
        clip_denoised: Forwarded to ``p_mean_variance``.
        model_kwargs: Dict read for ``"model_control_signal"`` (the control
            signal for the denoiser itself) and expanded into the
            ``langevin_fn`` call; must not be ``None``.

    Returns:
        The sample for the previous step, with the masked positions
        replaced by ``target`` noised to step ``t``.
    """
    step_batch = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
    # Only the model-level control signal is passed here; guidance
    # parameters reach langevin_fn through **model_kwargs below.
    model_mean, _, model_log_variance, _ = self.p_mean_variance(
        x=x,
        t=step_batch,
        clip_denoised=clip_denoised,
        control_signal=model_kwargs.get("model_control_signal", {}),
    )

    # no noise if t == 0
    if t > 0:
        noise = torch.randn_like(x)
    else:
        noise = 0.0
    sigma = (0.5 * model_log_variance).exp()

    refined = self.langevin_fn(
        sample=model_mean + sigma * noise,
        mean=model_mean,
        sigma=sigma,
        t=step_batch,
        tgt_embs=target,
        partial_mask=partial_mask,
        **model_kwargs,
    )

    # Fixed points: known positions take the target noised to this step.
    noised_target = self.q_sample(target, t=step_batch)
    refined[partial_mask] = noised_target[partial_mask]
    return refined
def classifier_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    y: torch.Tensor,
    classifier: torch.nn.Module
):
    """Gradient of the classifier's log-probability of ``y`` w.r.t. ``x``.

    NOTE(review): this definition takes no ``self`` although it sits among
    instance methods; confirm whether it is meant to be a ``@staticmethod``
    or a module-level function.

    Args:
        x: Input; a detached copy with gradients enabled is differentiated.
        t: Second argument passed to ``classifier``.
        y: Class indices, flattened before use, one per row of the logits.
        classifier: Callable returning logits for ``(x, t)``.

    Returns:
        A tensor shaped like ``x`` holding the gradient of the summed
        selected log-probabilities.
    """
    with torch.enable_grad():
        # Enable gradient tracking on a detached copy of the input.
        x_in = x.detach().requires_grad_(True)
        # Log-probabilities over classes.
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        # Pick the entry for each row's class y.
        rows = range(len(logits))
        picked = log_probs[rows, y.view(-1)]
        # Differentiate with respect to the input.
        grads = torch.autograd.grad(picked.sum(), x_in)
        return grads[0]
def langevin_fn(
    self,
    coef,
    partial_mask,
    tgt_embs,
    learning_rate,
    sample,
    mean,
    sigma,
    t,
    coef_=0.0,
    gradient_control_signal=None,
    model_control_signal=None,
    **kwargs,
):
    """Refine ``sample`` with a few guided gradient steps.

    Each step minimises the sum of three terms: a term keeping the sample
    near ``mean``, an infill term matching the denoiser output to
    ``tgt_embs`` on the masked region, and optional guidance terms chosen
    by ``gradient_control_signal``. More steps are taken at large ``t``.

    Args:
        coef: Weight of the stay-near-``mean`` term.
        partial_mask: Float mask when ``enable_float_mask`` is set,
            otherwise a mask used for boolean indexing.
        tgt_embs: Target values for the infill term.
        learning_rate: Base Adagrad learning rate; scaled down at smaller
            ``t``.
        sample: Starting point of the refinement.
        mean: Predicted mean the sample is pulled towards.
        sigma: Noise scale tensor for the current step.
        t: Batched step indices; only ``t[0]`` selects the step count.
        coef_: Scale of the Gaussian noise added after each step.
        gradient_control_signal: Dict of guidance settings (``"auc"``,
            ``"peak_points"``, ``"bar_regions"``, ``"target_freq"``, their
            weights, ``"gradient_scale"``, ``"max_grad_norm"``, ...).
            ``None`` means no guidance.
        model_control_signal: Control signal passed to the denoiser.
            ``None`` means an empty dict.
        **kwargs: ``enable_float_mask`` switches to float-mask blending;
            other keys are ignored.

    Returns:
        The refined sample. With a float mask this is a new tensor; with
        an indexing mask ``sample`` is modified in place and returned.
    """
    # Avoid shared mutable defaults; an explicit None is accepted too.
    if gradient_control_signal is None:
        gradient_control_signal = {}
    if model_control_signal is None:
        model_control_signal = {}
    use_float_mask = kwargs.get("enable_float_mask", False)

    # Run more gradient updates at large diffusion step t to guide the
    # generation, then reduce the number of steps in stages to speed up
    # sampling.
    step = t[0].item()
    if step < self.num_timesteps * 0.02:
        K = 0
    elif step > self.num_timesteps * 0.9:
        K = 3
    elif step > self.num_timesteps * 0.75:
        K = 2
        learning_rate = learning_rate * 0.5
    else:
        K = 1
        learning_rate = learning_rate * 0.25

    # NOTE(review): Parameter(sample) appears to share storage with
    # `sample`, so the first optimizer step would also update `sample` in
    # place — confirm this is intended.
    input_embs_param = torch.nn.Parameter(sample)

    # Time-dependent factor applied to every guidance weight.
    time_weight = get_time_dependent_weights(t[0], self.num_timesteps)
    max_grad_norm = gradient_control_signal.get("max_grad_norm", 1.0)

    with torch.enable_grad():
        for _ in range(K):
            # x_i+1 = x_i + noise * grad(logp(x_i)) + sqrt(2*noise) * z_i
            # The parameter is re-created every iteration, so the optimizer
            # must be rebuilt as well.
            optimizer = torch.optim.Adagrad([input_embs_param], lr=learning_rate)
            optimizer.zero_grad()

            x_start = self.output(
                x=input_embs_param,
                t=t,
                control_signal=model_control_signal,
            )

            if use_float_mask:
                infill_err = (x_start * (partial_mask) - tgt_embs * (partial_mask)) ** 2
            else:
                infill_err = (x_start[partial_mask] - tgt_embs[partial_mask]) ** 2
            sq_dev = (mean - input_embs_param) ** 2

            if sigma.mean() == 0:
                # Zero noise scale: skip the division by sigma.
                logp_term = coef * (sq_dev / 1.0).mean(dim=0).sum()
                infill_loss = infill_err.mean(dim=0).sum()
            else:
                logp_term = coef * (sq_dev / sigma).mean(dim=0).sum()
                infill_loss = (infill_err / sigma.mean()).mean(dim=0).sum()

            # Global gradient scaling factor shared by all guidance terms.
            gradient_scale = gradient_control_signal.get("gradient_scale", 1.0)
            control_loss = 0

            auc_sum = gradient_control_signal.get("auc")
            peak_points = gradient_control_signal.get("peak_points")
            bar_regions = gradient_control_signal.get("bar_regions")
            target_freq = gradient_control_signal.get("target_freq")

            # 1. Sum (area-under-curve) guidance.
            if auc_sum is not None:
                sum_weight = gradient_control_signal.get("auc_weight", 1.0) * time_weight
                control_loss += -sum_weight * sum_guidance(
                    x=input_embs_param,
                    t=t,
                    target_sum=auc_sum,
                    gradient_scale=gradient_scale,
                    segments=gradient_control_signal.get("segments", ()),
                )

            # 2. Peak guidance.
            if peak_points is not None:
                peak_weight = gradient_control_signal.get("peak_weight", 1.0) * time_weight
                control_loss += -peak_weight * peak_guidance(
                    x=input_embs_param,
                    t=t,
                    peak_points=peak_points,
                    window_size=gradient_control_signal.get("peak_window_size", 5),
                    alpha_1=gradient_control_signal.get("peak_alpha_1", 1.2),
                    gradient_scale=gradient_scale,
                )

            # 3. Region (bar) guidance.
            if bar_regions is not None:
                bar_weight = gradient_control_signal.get("bar_weight", 1.0) * time_weight
                control_loss += -bar_weight * bar_guidance(
                    x=input_embs_param,
                    t=t,
                    bar_regions=bar_regions,
                    gradient_scale=gradient_scale,
                )

            # 4. Frequency guidance.
            # NOTE(review): freq_weight is both the multiplier and an
            # argument of frequency_guidance; confirm it is not applied twice.
            if target_freq is not None:
                freq_weight = gradient_control_signal.get("freq_weight", 1.0) * time_weight
                control_loss += -freq_weight * frequency_guidance(
                    x=input_embs_param,
                    t=t,
                    target_freq=target_freq,
                    freq_weight=freq_weight,
                    gradient_scale=gradient_scale,
                )

            loss = logp_term + infill_loss + control_loss
            loss.backward()
            # Clip BEFORE stepping: clipping after optimizer.step() cannot
            # influence the update that was just applied.
            torch.nn.utils.clip_grad_norm_([input_embs_param], max_grad_norm)
            optimizer.step()

            epsilon = torch.randn_like(input_embs_param.data)
            noise_scale = coef_ * sigma.mean().item()
            input_embs_param = torch.nn.Parameter(
                (input_embs_param.data + noise_scale * epsilon).detach()
            )

    if use_float_mask:
        sample = sample * partial_mask + input_embs_param.data * (1 - partial_mask)
    else:
        sample[~partial_mask] = input_embs_param.data[~partial_mask]
    return sample
def predict_weighted_points(
    self,
    observed_points: torch.Tensor,
    observed_mask: torch.Tensor,
    coef=1e-1,
    stepsize=1e-1,
    sampling_steps=50,
    **kargs,
):
    """Infill a single series from weighted observed points.

    The input is mapped to ``[-1, 1]`` with ``x * 2 - 1``, infilled with
    ``fast_sample_infill_float_mask``, and mapped back with
    ``(sample + 1) / 2``.

    Args:
        observed_points: 2-D tensor of observations (no batch dimension).
        observed_mask: Mask tensor, unsqueezed like ``observed_points``;
            values greater than 0 mark observed entries and are used as
            float weights.
        coef: Passed to the sampler as ``model_kwargs["coef"]``.
        stepsize: Passed as ``model_kwargs["learning_rate"]``.
        sampling_steps: Number of strided sampling steps.
        **kargs: Extra entries merged into ``model_kwargs``; they override
            ``coef`` / ``learning_rate`` on key collision.

    Returns:
        A NumPy array with the batch dimension removed.

    Raises:
        NotImplementedError: If ``sampling_steps`` equals
            ``self.num_timesteps`` (full-length sampling is not available).
    """
    model_kwargs = {"coef": coef, "learning_rate": stepsize, **kargs}

    assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1"

    # Move inputs to the device holding the model's buffers. The previous
    # code used the input's own device, which made the `.to(...)` calls
    # no-ops and left the inputs where they were.
    device = self.betas.device
    self.device = device

    x = observed_points.unsqueeze(0).to(device)
    # 1 (or a positive weight) for observed, 0 for missing
    float_mask = observed_mask.unsqueeze(0).to(device)
    binary_mask = float_mask.clone()
    binary_mask[binary_mask > 0] = 1

    x = x * 2 - 1  # normalize

    if sampling_steps == self.num_timesteps:
        print("normal sampling")
        raise NotImplementedError(
            "full-length infill sampling is not implemented; "
            "use sampling_steps < num_timesteps"
        )

    print("fast sampling")
    sample = self.fast_sample_infill_float_mask(
        shape=x.shape,
        target=x * binary_mask,  # x * t_m, 1 for observed, 0 for missing
        partial_mask=float_mask,
        model_kwargs=model_kwargs,
        sampling_timesteps=sampling_steps,
    )

    # unnormalize
    sample = (sample + 1) / 2
    return sample.squeeze(0).detach().cpu().numpy()