# github-actions[bot]
# Sync from GitHub 33c12db74322f3d28409b5dc0a8c441914c9178b
# e0552b0
from einops import rearrange
import torch
from torch import Tensor
from jaxtyping import Float, Int, Bool
from typing import Optional
from koja_diffuser.train.debug import Debug
from koja_diffuser.tokenizer.special import SpecialToken
class MMDLoss:
    """Maximum mean discrepancy (MMD) losses using a mixture of RBF kernels.

    Every member is a static method; the class is used purely as a namespace.
    """

    @staticmethod
    def flatten_latent(z: Float[Tensor, "B L S"]) -> Float[Tensor, "B LS"]:
        """Merge the last two axes of a 3-axis tensor, giving one row per batch item."""
        return rearrange(z, "b l s -> b (l s)")

    @staticmethod
    def pairwise_sq_dist(x: Tensor, y: Tensor) -> Tensor:
        """Return squared Euclidean distances between the rows of ``x`` and ``y``.

        Entry ``[i, j]`` is computed through the expansion
        ``|x_i|^2 + |y_j|^2 - 2 * <x_i, y_j>``. The result is clamped at zero
        so that no negative value is returned.
        """
        x_sq = x.pow(2).sum(dim=1, keepdim=True)
        y_sq = y.pow(2).sum(dim=1, keepdim=True).transpose(0, 1)
        cross = x @ y.transpose(0, 1)
        sq_dist = x_sq + y_sq - 2.0 * cross
        return sq_dist.clamp_min(0.0)

    @staticmethod
    @torch.no_grad()
    def estimate_bandwidth(
        x: Float[Tensor, "B LS"], y: Float[Tensor, "B LS"], eps=1e-4
    ) -> Tensor:
        """Choose an RBF bandwidth with the median heuristic.

        ``x`` and ``y`` are pooled, and the square root of the median of the
        strictly positive pairwise squared distances is returned, bounded
        below by ``eps``. When no pairwise distance is positive the result is
        a scalar tensor holding 1.0. Runs with gradient tracking disabled.
        """
        pooled = torch.cat([x, y], dim=0)
        sq_dists = MMDLoss.pairwise_sq_dist(pooled, pooled)
        positive = sq_dists > 0
        if not positive.any():
            # Nothing to take a median of; use a unit bandwidth.
            return pooled.new_tensor(1.0)
        return sq_dists[positive].median().sqrt().clamp_min(eps)

    @staticmethod
    def off_diagonal_mean(k: Tensor) -> Tensor:
        """Average the entries of ``k`` that lie off the main diagonal.

        If ``k`` has at most one row there are no off-diagonal entries, and
        the plain mean of ``k`` is returned instead.
        """
        size = k.size(0)
        if size > 1:
            off_diag = ~torch.eye(size, dtype=torch.bool, device=k.device)
            return k[off_diag].mean()
        return k.mean()

    @staticmethod
    def mmd_rbf_loss(
        z_fake: Float[Tensor, "B L S"], z_real: Float[Tensor, "B L S"]
    ) -> Tensor:
        """Return an MMD estimate between ``z_fake`` and ``z_real``.

        Both inputs are cast to float32 and flattened to one row per batch
        item. ``z_real`` is detached, so gradients only reach ``z_fake``.
        The kernel is the average of three RBF kernels whose bandwidths are
        0.5x, 1x and 2x the value from ``estimate_bandwidth``. The result is
        ``offdiag_mean(K_ff) + offdiag_mean(K_rr) - 2 * mean(K_fr)``.
        """
        fake = MMDLoss.flatten_latent(z_fake.float())
        real = MMDLoss.flatten_latent(z_real.detach().float())

        base = MMDLoss.estimate_bandwidth(fake.detach(), real.detach())
        denominators = [
            2.0 * (scale**2) for scale in (base * 0.5, base, base * 2.0)
        ]

        def mixture_kernel(sq_dist: Tensor) -> Tensor:
            # Average of exp(-d / (2 * sigma^2)) over the three bandwidths.
            total = torch.zeros_like(sq_dist)
            for denom in denominators:
                total = total + torch.exp(-sq_dist / denom)
            return total / len(denominators)

        k_fake = mixture_kernel(MMDLoss.pairwise_sq_dist(fake, fake))
        k_real = mixture_kernel(MMDLoss.pairwise_sq_dist(real, real))
        k_cross = mixture_kernel(MMDLoss.pairwise_sq_dist(fake, real))

        within_fake = MMDLoss.off_diagonal_mean(k_fake)
        within_real = MMDLoss.off_diagonal_mean(k_real)
        return within_fake + within_real - 2.0 * k_cross.mean()

    @staticmethod
    def direct_domain_loss(
        *,
        z_ja_hat: Tensor,
        z_ko_hat: Tensor,
        z_ja: Tensor,
        z_ko: Tensor,
        d: Optional[Debug] = None,
    ):
        """Sum the MMD losses of the ``ja`` and ``ko`` latent pairs.

        Each ``*_hat`` tensor is compared with its counterpart through
        ``mmd_rbf_loss``. When ``d`` is given, the two individual terms and
        their sum are passed to ``d.loss.domain_ja``, ``d.loss.domain_ko``
        and ``d.loss.domain_total``. Returns the summed loss.
        """
        ja_term = MMDLoss.mmd_rbf_loss(z_ja_hat, z_ja)
        ko_term = MMDLoss.mmd_rbf_loss(z_ko_hat, z_ko)
        combined = ja_term + ko_term
        if d is None:
            return combined
        d.loss.domain_ja(ja_term)
        d.loss.domain_ko(ko_term)
        d.loss.domain_total(combined)
        return combined
def repeat_penalty_loss(
    logits: Tensor,
    *,
    temperature=1.0,
    exclude_token_ids: tuple[int, ...] = (SpecialToken.eos,),
) -> Tensor:
    """Penalize neighbouring positions that place probability on the same token.

    ``logits`` is indexed as ``[batch, position, vocab]``. It is cast to
    float32, divided by ``temperature`` and soft-maxed over the last axis.
    Tokens listed in ``exclude_token_ids`` get probability 0 and the
    remaining mass is renormalized (the divisor is floored at 1e-8). For each
    pair of adjacent positions the inner product of the two distributions is
    taken, and the mean over all pairs and batch items is returned.

    Returns a scalar tensor. When there is no adjacent pair to score (fewer
    than two positions, or an empty batch) the result is 0 rather than the
    NaN that a mean over an empty tensor would give.
    """
    probs = (logits.float() / temperature).softmax(dim=-1)
    if exclude_token_ids:
        # Zero the excluded columns on a copy, leaving the softmax output intact.
        probs = probs.clone()
        for token_id in exclude_token_ids:
            probs[..., token_id] = 0.0
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-8)

    # Overlap between the distribution at position t and the one at t + 1.
    repeat_prob = (probs[:, :-1, :] * probs[:, 1:, :]).sum(dim=-1)
    if repeat_prob.numel() == 0:
        # The empty sum is 0 and still attached to the autograd graph.
        return repeat_prob.sum()
    return repeat_prob.mean()
class CenterOneSepLoss:
    """Loss terms that favour a single SEP token near the middle of the sequence.

    The sequence considered for each batch item is the span before the first
    position whose most likely token is EOS.
    """

    def __init__(self, logits: Float[Tensor, "B L S"], temperature=1.0):
        """Store ``logits`` and its temperature-scaled float32 softmax."""
        self.logits = logits
        self.device = logits.device
        scaled = logits.float() / temperature
        self.probs = scaled.softmax(dim=-1)

    def get_len(self) -> Int[Tensor, "B"]:
        """Return the per-item sequence length, not counting EOS.

        The length is the index of the first position whose argmax token is
        ``SpecialToken.eos``, or the full sequence length when no position
        predicts EOS.
        """
        greedy_tokens = self.probs.argmax(dim=-1)
        eos_hits = greedy_tokens == SpecialToken.eos
        # argmax over the 0/1 mask gives the first EOS position (0 if none).
        first_eos = eos_hits.int().argmax(dim=-1)
        full_len = torch.tensor(self.probs.size(1), device=self.device)
        return torch.where(eos_hits.any(dim=-1), first_eos, full_len)

    def mirror_tensor(
        self, probs_len: Int[Tensor, "B"]
    ) -> tuple[Float[Tensor, "B L"], Bool[Tensor, "B L"]]:
        """Build a per-position weight that is 0 at the ends and peaks in the middle.

        For a length ``n`` the weight at position ``p`` is
        ``min(p, n - 1 - p) / max((n - 1) // 2, 1)``. Positions at or beyond
        ``n`` get weight 0. For ``n`` of 1 or 2 every weight is 0.

        Returns the weights and the boolean mask of positions below ``n``.
        """
        seq_len = self.logits.size(1)
        pos = rearrange(torch.arange(seq_len, device=self.device), "l -> 1 l")
        last = probs_len[:, None] - 1
        edge_dist = torch.minimum(pos, last - pos)
        half = (last // 2).clamp_min(1)
        valid = pos < probs_len[:, None]
        weights = (edge_dist.float() / half.float()).masked_fill(~valid, 0.0)
        return weights, valid

    def loss(self):
        """Return the ``(center_loss, count_loss, peak_loss)`` triple.

        * center_loss: SEP probability weighted by ``1 - mirror weight``,
          summed over positions and averaged over the batch.
        * count_loss: squared gap between the summed SEP probability and the
          target count (1 when the length is positive, otherwise 0).
        * peak_loss: squared gap between the largest SEP probability and the
          same target.

        SEP probability outside the valid span is zeroed before all three.
        """
        lengths = self.get_len()
        weights, valid = self.mirror_tensor(lengths)
        sep_mass = self.probs[..., SpecialToken.sep].masked_fill(~valid, 0.0)

        # Center loss
        off_center = 1.0 - weights
        center_loss = (sep_mass * off_center).sum(dim=-1).mean()

        target = (lengths > 0).float()

        # Count loss
        expected_count = sep_mass.sum(dim=-1)
        count_loss = ((expected_count - target) ** 2).mean()

        # Peak loss
        strongest = sep_mass.max(dim=-1).values
        peak_loss = ((strongest - target) ** 2).mean()

        return center_loss, count_loss, peak_loss