|
|
| import torch
|
| import torch.nn.functional as F
|
|
|
class LinearNoiseScheduler:
    """Linear-beta diffusion noise scheduler.

    Precomputes the standard DDPM quantities for a linear beta schedule and
    provides three samplers: ancestral DDPM (`step`), DDIM (`ddim_step`),
    and a simplified second-order DPM-Solver update (`dpm_solver_multistep`).
    """

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """Build the beta schedule and all derived constants.

        Args:
            num_timesteps: number of diffusion steps T.
            beta_start: variance of the first (least noisy) step.
            beta_end: variance of the last (noisiest) step.
        """
        self.num_timesteps = num_timesteps

        # Linearly spaced noise variances beta_t on [beta_start, beta_end].
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)

        # alpha_t = 1 - beta_t and the running product alpha_bar_t.
        # (use dim=, not the NumPy alias axis=)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Cached square roots used by the forward process (add_noise).
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1 (no noise).
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        # Variance of the true posterior q(x_{t-1} | x_t, x_0) (DDPM eq. 7).
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

        # Default sampling order: T-1, T-2, ..., 0.
        self.timesteps = torch.arange(0, num_timesteps).flip(0)

    def set_timesteps(self, num_inference_steps, device=None):
        """Set the discrete timesteps used for the diffusion chain.

        Args:
            num_inference_steps: how many sampling steps to use.
            device: device for the timestep tensor; defaults to the device
                of the precomputed schedule tensors.
        """
        device_to_use = device if device is not None else self.betas.device
        # Bug fix: the previous torch.linspace(..., dtype=torch.long) call
        # truncated fractional values (integer dtypes in linspace are also
        # deprecated), giving unevenly spaced timesteps. Compute in float
        # and round to the nearest integer instead.
        steps = torch.linspace(self.num_timesteps - 1, 0, num_inference_steps, device=device_to_use)
        self.timesteps = steps.round().long()

    def to(self, device):
        """Move all scheduler tensors to *device*; returns self for chaining."""
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        self.alphas_cumprod_prev = self.alphas_cumprod_prev.to(device)
        self.posterior_variance = self.posterior_variance.to(device)
        # Bug fix: timesteps was previously left behind on the old device.
        self.timesteps = self.timesteps.to(device)
        return self

    def add_noise(self, original_samples, noise, timesteps):
        """Diffuse clean samples to noise level *timesteps* (forward process).

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps

        Assumes 4-D (batch, C, H, W) inputs: the per-sample coefficients are
        broadcast via view(-1, 1, 1, 1).
        """
        device = timesteps.device
        sqrt_ab_t = self.sqrt_alphas_cumprod.to(device)[timesteps].view(-1, 1, 1, 1)
        sqrt_one_minus_ab_t = self.sqrt_one_minus_alphas_cumprod.to(device)[timesteps].view(-1, 1, 1, 1)

        return sqrt_ab_t * original_samples + sqrt_one_minus_ab_t * noise

    def step(self, model_output, timestep, sample):
        """One ancestral DDPM sampling step: x_t -> x_{t-1}.

        Args:
            model_output: predicted noise eps_theta(x_t, t).
            timestep: current scalar timestep t.
            sample: current latent x_t.

        Returns:
            x_{t-1}, or the clamped x_0 prediction when t == 0.
        """
        t = timestep
        alpha_t = self.alphas[t]
        alpha_bar_t = self.alphas_cumprod[t]
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alphas_cumprod[t]

        # Predict x_0 from the noise estimate and clamp to the data range.
        pred_original_sample = (sample - sqrt_one_minus_alpha_bar_t * model_output) / torch.sqrt(alpha_bar_t)
        pred_original_sample = torch.clamp(pred_original_sample, -1., 1.)

        if t == 0:
            # Final step: return the denoised prediction directly.
            return pred_original_sample

        alpha_bar_t_prev = self.alphas_cumprod_prev[t]
        posterior_variance_t = self.posterior_variance[t]

        # Posterior mean q(x_{t-1} | x_t, x_0) = coef_x0 * x0_hat + coef_xt * x_t
        # (DDPM eq. 7).
        coef_x0 = torch.sqrt(alpha_bar_t_prev) * self.betas[t] / (1. - alpha_bar_t)
        coef_xt = torch.sqrt(alpha_t) * (1. - alpha_bar_t_prev) / (1. - alpha_bar_t)
        prev_sample_mean = coef_xt * sample + coef_x0 * pred_original_sample

        # t > 0 is guaranteed here by the early return above, so noise is
        # always injected (the old `if t > 0 else zeros` branch was dead code).
        noise = torch.randn_like(model_output)
        return prev_sample_mean + torch.sqrt(posterior_variance_t) * noise

    def ddim_step(self, model_output, timestep, sample, eta=0.0, prev_timestep=None):
        """DDIM sampling step (Song et al., 2021).

        eta=0.0 gives deterministic DDIM; eta=1.0 gives DDPM-like
        stochasticity. When *prev_timestep* is None (final step) the clamped
        x_0 prediction is returned instead.
        """
        alpha_bar_t = self.alphas_cumprod[timestep]

        # Predict x_0 from the noise estimate and clamp to the data range.
        pred_original_sample = (sample - torch.sqrt(1 - alpha_bar_t) * model_output) / torch.sqrt(alpha_bar_t)
        pred_original_sample = torch.clamp(pred_original_sample, -1.0, 1.0)

        if prev_timestep is None:
            return pred_original_sample

        alpha_bar_prev = self.alphas_cumprod[prev_timestep]

        # DDIM sigma_t (paper eq. 16); exactly zero when eta == 0.
        sigma_t = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev))

        # "Direction pointing to x_t" term.
        pred_sample_direction = torch.sqrt(1 - alpha_bar_prev - sigma_t ** 2) * model_output

        prev_sample = torch.sqrt(alpha_bar_prev) * pred_original_sample + pred_sample_direction

        # Inject noise only in the stochastic (eta > 0) regime.
        if eta > 0:
            prev_sample = prev_sample + sigma_t * torch.randn_like(model_output)

        return prev_sample

    def dpm_solver_multistep(self, model_output, timestep, sample, order=2, prev_timestep=None, prev_model_output=None):
        """Simplified DPM-Solver-style multistep update.

        Uses a first-order (deterministic-DDIM-equivalent) update when
        order == 1 or no previous noise prediction is available; otherwise
        applies a second-order correction in log-SNR space using the
        previous model output. Returns the clamped x_0 prediction when
        *prev_timestep* is None (final step).
        """
        alpha_bar_t = self.alphas_cumprod[timestep]

        # Predict x_0 from the noise estimate and clamp to the data range.
        pred_original_sample = (sample - torch.sqrt(1 - alpha_bar_t) * model_output) / torch.sqrt(alpha_bar_t)
        pred_original_sample = torch.clamp(pred_original_sample, -1.0, 1.0)

        if prev_timestep is None:
            return pred_original_sample

        prev_t = prev_timestep
        # alpha_bar_{-1} is defined as 1 (fully denoised).
        alpha_bar_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.alphas_cumprod.new_tensor(1.0)

        if order == 1 or prev_model_output is None:
            # First-order update == deterministic DDIM.
            return torch.sqrt(alpha_bar_prev) * pred_original_sample + torch.sqrt(1 - alpha_bar_prev) * model_output

        # log-SNR: lambda = log(alpha/sigma) = 0.5 * log(ab / (1 - ab)).
        # NOTE(review): lambda_prev is infinite when alpha_bar_prev == 1
        # (prev_t < 0); callers should use order=1 for the final step.
        lambda_t = 0.5 * torch.log(alpha_bar_t / (1 - alpha_bar_t))
        lambda_prev = 0.5 * torch.log(alpha_bar_prev / (1 - alpha_bar_prev))
        h = lambda_prev - lambda_t

        # Second-order correction: linearly extrapolate the noise estimate
        # from the two most recent model outputs.
        corrected_eps = model_output + h * (model_output - prev_model_output) / 2
        return torch.sqrt(alpha_bar_prev) * pred_original_sample + torch.sqrt(1 - alpha_bar_prev) * corrected_eps