# model/scheduler.py — repository-page residue commented out
# (upload "Upload 27 files", commit 1913ec5, by NZUONG)
# file: model/scheduler.py
import torch
import torch.nn.functional as F
class LinearNoiseScheduler:
    """DDPM noise scheduler with a linear beta schedule.

    Precomputes the alpha/beta schedule tensors and provides:
      * the forward (noising) process via :meth:`add_noise`,
      * ancestral DDPM sampling via :meth:`step`,
      * deterministic/stochastic DDIM sampling via :meth:`ddim_step`,
      * a first/second-order DPM-Solver-style update via
        :meth:`dpm_solver_multistep`.
    """

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """Precompute the schedule tensors.

        Args:
            num_timesteps: number of training diffusion steps T.
            beta_start: beta value at t = 0.
            beta_end: beta value at t = T - 1.
        """
        self.num_timesteps = num_timesteps
        # Linear beta schedule.
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        # Alphas and their cumulative product (alpha-bar).
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Coefficients for the forward (noising) process.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # Coefficients for the reverse (sampling) process.
        # alpha_bar at t = -1 is defined as 1 (clean-data limit).
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.posterior_variance = (
            self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        )
        # Default inference schedule: every training step, in reverse order.
        self.timesteps = torch.arange(0, num_timesteps).flip(0)

    def set_timesteps(self, num_inference_steps, device=None):
        """Set the discrete timesteps used for the sampling loop.

        Produces `num_inference_steps` evenly spaced timesteps, decreasing
        from T - 1 down to 0.

        Args:
            num_inference_steps: number of denoising steps to run.
            device: device for the timestep tensor (defaults to the
                schedule's current device).
        """
        device_to_use = device if device is not None else self.betas.device
        # Compute in floating point then truncate to int64: integer-dtype
        # torch.linspace is deprecated/unreliable in recent PyTorch releases,
        # and .long() truncates exactly as the integer linspace did.
        self.timesteps = torch.linspace(
            self.num_timesteps - 1, 0, num_inference_steps, device=device_to_use
        ).long()

    def to(self, device):
        """Move every schedule tensor to `device`; returns self for chaining."""
        for name in (
            "betas",
            "alphas",
            "alphas_cumprod",
            "sqrt_alphas_cumprod",
            "sqrt_one_minus_alphas_cumprod",
            "alphas_cumprod_prev",
            "posterior_variance",
        ):
            setattr(self, name, getattr(self, name).to(device))
        return self

    def add_noise(self, original_samples, noise, timesteps):
        """Forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps.

        Args:
            original_samples: clean samples x_0, shape (B, C, H, W).
            noise: Gaussian noise eps with the same shape as x_0.
            timesteps: per-sample timestep indices, shape (B,).

        Returns:
            The noised samples x_t.
        """
        device = timesteps.device
        # Gather per-sample coefficients and broadcast over (C, H, W).
        sqrt_ab_t = self.sqrt_alphas_cumprod.to(device)[timesteps].view(-1, 1, 1, 1)
        sqrt_1m_ab_t = (
            self.sqrt_one_minus_alphas_cumprod.to(device)[timesteps].view(-1, 1, 1, 1)
        )
        return sqrt_ab_t * original_samples + sqrt_1m_ab_t * noise

    def _predict_x0(self, model_output, timestep, sample):
        """Recover the clamped x_0 estimate from a noise prediction.

        Shared by all sampling methods:
            x_0 = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t), clamped
            to [-1, 1] (samples are assumed normalized to that range).
        """
        alpha_bar_t = self.alphas_cumprod[timestep]
        pred_x0 = (
            sample - torch.sqrt(1 - alpha_bar_t) * model_output
        ) / torch.sqrt(alpha_bar_t)
        return torch.clamp(pred_x0, -1.0, 1.0)

    def step(self, model_output, timestep, sample):
        """Ancestral DDPM reverse step: sample x_{t-1} ~ p(x_{t-1} | x_t).

        Args:
            model_output: predicted noise eps_theta(x_t, t).
            timestep: current timestep t (int or 0-dim tensor).
            sample: current latent x_t.

        Returns:
            x_{t-1}; at t == 0, the final (clamped) x_0 estimate.
        """
        t = timestep
        pred_original_sample = self._predict_x0(model_output, t, sample)
        if t == 0:
            # Final step: no noise is added, return the x_0 estimate.
            return pred_original_sample
        alpha_t = self.alphas[t]
        alpha_bar_t = self.alphas_cumprod[t]
        alpha_bar_t_prev = self.alphas_cumprod_prev[t]
        # Posterior mean q(x_{t-1} | x_t, x_0) coefficients (DDPM Eq. 7).
        coef_x0 = torch.sqrt(alpha_bar_t_prev) * self.betas[t] / (1. - alpha_bar_t)
        coef_xt = torch.sqrt(alpha_t) * (1. - alpha_bar_t_prev) / (1. - alpha_bar_t)
        prev_sample_mean = coef_xt * sample + coef_x0 * pred_original_sample
        # t > 0 here (the t == 0 case returned above), so noise is always added;
        # the original's redundant `if t > 0 else zeros` branch was unreachable.
        noise = torch.randn_like(model_output)
        return prev_sample_mean + torch.sqrt(self.posterior_variance[t]) * noise

    def ddim_step(self, model_output, timestep, sample, eta=0.0, prev_timestep=None):
        """DDIM sampling step (Song et al.): deterministic for eta = 0.

        Args:
            model_output: predicted noise eps_theta(x_t, t).
            timestep: current timestep t.
            sample: current latent x_t.
            eta: stochasticity coefficient; 0.0 gives DDIM, 1.0 is DDPM-like.
            prev_timestep: previous timestep in the inference schedule; None
                marks the final step, which returns the x_0 estimate directly.

        Returns:
            x_{t_prev} (or the clamped x_0 estimate on the final step).
        """
        pred_original_sample = self._predict_x0(model_output, timestep, sample)
        if prev_timestep is None:
            return pred_original_sample
        alpha_bar_t = self.alphas_cumprod[timestep]
        alpha_bar_prev = self.alphas_cumprod[prev_timestep]
        # Variance term; identically zero when eta == 0.
        sigma_t = eta * torch.sqrt(
            (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
        )
        # "Direction pointing to x_t" term of the DDIM update.
        pred_sample_direction = torch.sqrt(1 - alpha_bar_prev - sigma_t**2) * model_output
        prev_sample = torch.sqrt(alpha_bar_prev) * pred_original_sample + pred_sample_direction
        if eta > 0:
            prev_sample = prev_sample + sigma_t * torch.randn_like(model_output)
        return prev_sample

    def dpm_solver_multistep(self, model_output, timestep, sample, order=2,
                             prev_timestep=None, prev_model_output=None):
        """DPM-Solver-style multistep update (first or second order).

        Args:
            model_output: predicted noise at the current timestep.
            timestep: current timestep t.
            sample: current latent x_t.
            order: 1 for a DDIM-like step; 2 enables the second-order
                correction (requires `prev_model_output`).
            prev_timestep: next (smaller) timestep; None marks the final step.
            prev_model_output: noise prediction from the previous solver call,
                used by the second-order correction.

        Returns:
            x_{t_prev} (or the clamped x_0 estimate on the final step).
        """
        pred_original_sample = self._predict_x0(model_output, timestep, sample)
        if prev_timestep is None:
            return pred_original_sample
        alpha_bar_t = self.alphas_cumprod[timestep]
        # Negative indices denote the clean-data limit, alpha_bar = 1.
        if prev_timestep >= 0:
            alpha_bar_prev = self.alphas_cumprod[prev_timestep]
        else:
            alpha_bar_prev = self.alphas_cumprod.new_tensor(1.0)
        if order == 1 or prev_model_output is None:
            # First-order (DDIM-like) update.
            prev_sample = (
                torch.sqrt(alpha_bar_prev) * pred_original_sample
                + torch.sqrt(1 - alpha_bar_prev) * model_output
            )
        else:
            # Second-order correction in the half-log-SNR (lambda) domain.
            lambda_t = 0.5 * torch.log(alpha_bar_t / (1 - alpha_bar_t))
            lambda_prev = 0.5 * torch.log(alpha_bar_prev / (1 - alpha_bar_prev))
            h = lambda_prev - lambda_t
            prev_sample = (
                torch.sqrt(alpha_bar_prev) * pred_original_sample
                + torch.sqrt(1 - alpha_bar_prev)
                * (model_output + h * (model_output - prev_model_output) / 2)
            )
        return prev_sample