|
|
|
|
|
|
|
|
import math |
|
|
import torch |
|
|
|
|
|
|
|
|
def project( |
|
|
v0: torch.Tensor, |
|
|
v1: torch.Tensor, |
|
|
): |
|
|
dtype = v0.dtype |
|
|
v0, v1 = v0.double(), v1.double() |
|
|
v1 = torch.nn.functional.normalize(v1, dim=[-1,-2]) |
|
|
v0_parallel = (v0 * v1).sum(dim=[-1,-2], keepdim=True) * v1 |
|
|
v0_orthogonal = v0 - v0_parallel |
|
|
return v0_parallel.to(dtype), v0_orthogonal.to(dtype) |
|
|
|
|
|
|
|
|
def normalized_guidance( |
|
|
pred_cond: torch. Tensor, |
|
|
pred_uncond: torch.Tensor, |
|
|
guidance_scale: float, |
|
|
momentum_buffer: None, |
|
|
eta: float = 1.0, |
|
|
norm_threshold: float = 0.0, |
|
|
): |
|
|
B, seq_len, dim = pred_cond.shape |
|
|
diff = pred_cond - pred_uncond |
|
|
if norm_threshold > 0: |
|
|
ones = torch.ones_like(diff) |
|
|
diff_norm = 1/math.sqrt(seq_len*dim) * diff.norm(p=2, dim=[-1, -2], keepdim=True) |
|
|
scale_factor = torch.minimum(ones, norm_threshold / diff_norm) |
|
|
diff = diff * scale_factor |
|
|
diff_parallel, diff_orthogonal = project(diff, pred_cond) |
|
|
normalized_update = diff_orthogonal + eta * diff_parallel |
|
|
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update |
|
|
return pred_guided |