Spaces:
Sleeping
Sleeping
| from einops import rearrange | |
| import torch | |
| from torch import Tensor | |
| from jaxtyping import Float, Int, Bool | |
| from typing import Optional | |
| from koja_diffuser.train.debug import Debug | |
| from koja_diffuser.tokenizer.special import SpecialToken | |
class MMDLoss:
    """Maximum Mean Discrepancy (MMD) between two batches of latents.

    The kernel is an equal-weight mixture of three RBF kernels whose
    bandwidths are 0.5x, 1x and 2x a median-heuristic estimate. Every member
    is a stateless helper, so all of them are static methods: they can be
    called as ``MMDLoss.name(...)`` or on an instance with the same result.
    """

    @staticmethod
    def flatten_latent(z: Float[Tensor, "B L S"]) -> Float[Tensor, "B LS"]:
        """Merge the last two axes of a 3-D latent into one feature axis."""
        return rearrange(z, "b l s -> b (l s)")

    @staticmethod
    def pairwise_sq_dist(x: Tensor, y: Tensor) -> Tensor:
        """Return squared Euclidean distances between the rows of x and y.

        Args:
            x: 2-D tensor of shape (N, D).
            y: 2-D tensor of shape (M, D).

        Returns:
            An (N, M) tensor whose entry (i, j) is ``|x[i] - y[j]|^2``,
            clamped to be non-negative.
        """
        x_norm = (x**2).sum(dim=1, keepdim=True)
        y_norm = (y**2).sum(dim=1, keepdim=True).transpose(0, 1)
        dist = x_norm + y_norm - 2.0 * (x @ y.transpose(0, 1))
        # The expanded form |x|^2 + |y|^2 - 2 x.y can dip below zero through
        # floating-point rounding, hence the clamp.
        return dist.clamp_min(0.0)

    @staticmethod
    def estimate_bandwidth(
        x: Float[Tensor, "B LS"], y: Float[Tensor, "B LS"], eps=1e-4
    ) -> Tensor:
        """Estimate an RBF bandwidth with the median heuristic.

        The bandwidth is the square root of the median squared distance over
        all pairs of distinct samples in the pooled set ``[x; y]`` that do not
        coincide.

        Args:
            x: First sample batch, one row per sample.
            y: Second sample batch, one row per sample.
            eps: Lower bound applied to the returned bandwidth.

        Returns:
            A scalar tensor. Falls back to 1.0 when no pair of distinct
            samples has a positive distance.
        """
        z = torch.cat([x, y], dim=0)
        d2 = MMDLoss.pairwise_sq_dist(z, z)
        # Self-distances are mathematically zero, but rounding in the expanded
        # distance formula can leave them slightly positive. Drop the diagonal
        # explicitly so that noise never enters the median.
        not_self = ~torch.eye(z.size(0), dtype=torch.bool, device=z.device)
        mask = (d2 > 0) & not_self
        if mask.any():
            sigma = d2[mask].median().sqrt()  # median heuristic
            return sigma.clamp_min(eps)
        return z.new_tensor(1.0)

    @staticmethod
    def off_diagonal_mean(k: Tensor) -> Tensor:
        """Return the mean of a square matrix with its diagonal left out.

        Args:
            k: Square (n, n) matrix.

        Returns:
            A scalar tensor. When ``n <= 1`` there are no off-diagonal
            entries, so the plain mean of ``k`` is returned instead.
        """
        n = k.size(0)
        if n <= 1:
            return k.mean()
        mask = ~torch.eye(n, dtype=torch.bool, device=k.device)
        return k[mask].mean()

    @staticmethod
    def mmd_rbf_loss(
        z_fake: Float[Tensor, "B L S"], z_real: Float[Tensor, "B L S"]
    ) -> Tensor:
        """Compute the multi-bandwidth RBF MMD between two latent batches.

        ``z_real`` is detached and the bandwidth is estimated on detached
        inputs, so gradients flow only through ``z_fake`` via the kernel
        values. All arithmetic is carried out in float32.

        Args:
            z_fake: Generated latents.
            z_real: Reference latents.

        Returns:
            A scalar tensor: off-diagonal mean of k(x, x) plus off-diagonal
            mean of k(y, y) minus twice the mean of k(x, y).
        """
        x = MMDLoss.flatten_latent(z_fake.float())
        y = MMDLoss.flatten_latent(z_real.detach().float())

        sigma = MMDLoss.estimate_bandwidth(x.detach(), y.detach())
        sigmas = [sigma * 0.5, sigma, sigma * 2.0]

        def mixture_kernel(sq_dist: Tensor) -> Tensor:
            # Equal-weight average of RBF kernels over all bandwidths.
            total = torch.zeros_like(sq_dist)
            for s in sigmas:
                total = total + torch.exp(-sq_dist / (2.0 * (s**2)))
            return total / len(sigmas)

        k_xx = mixture_kernel(MMDLoss.pairwise_sq_dist(x, x))
        k_yy = mixture_kernel(MMDLoss.pairwise_sq_dist(y, y))
        k_xy = mixture_kernel(MMDLoss.pairwise_sq_dist(x, y))

        return (
            MMDLoss.off_diagonal_mean(k_xx)
            + MMDLoss.off_diagonal_mean(k_yy)
            - 2.0 * k_xy.mean()
        )
def direct_domain_loss(
    *,
    z_ja_hat: Tensor,
    z_ko_hat: Tensor,
    z_ja: Tensor,
    z_ko: Tensor,
    d: Optional[Debug] = None,
):
    """Sum the MMD losses of the two latent domains.

    Args:
        z_ja_hat: Latents compared against ``z_ja``.
        z_ko_hat: Latents compared against ``z_ko``.
        z_ja: Reference latents for the first pair.
        z_ko: Reference latents for the second pair.
        d: Optional debug object. When given, the two per-domain losses and
            their sum are passed to ``d.loss.domain_ja``,
            ``d.loss.domain_ko`` and ``d.loss.domain_total``.

    Returns:
        ``mmd_rbf_loss(z_ja_hat, z_ja) + mmd_rbf_loss(z_ko_hat, z_ko)``.
    """
    ja_term = MMDLoss.mmd_rbf_loss(z_ja_hat, z_ja)
    ko_term = MMDLoss.mmd_rbf_loss(z_ko_hat, z_ko)
    combined = ja_term + ko_term

    # Nothing to report without a debug object.
    if d is None:
        return combined

    d.loss.domain_ja(ja_term)
    d.loss.domain_ko(ko_term)
    d.loss.domain_total(combined)
    return combined
def repeat_penalty_loss(
    logits: Tensor,
    *,
    temperature=1.0,
    exclude_token_ids: tuple[int, ...] = (SpecialToken.eos,),
) -> Tensor:
    """Penalise putting probability on the same token at adjacent positions.

    For every pair of neighbouring positions the loss is the inner product of
    their token distributions, i.e. the probability that both positions emit
    the same token if sampled independently.

    Args:
        logits: 3-D tensor indexed as (batch, position, token).
        temperature: Divisor applied to the logits before the softmax.
        exclude_token_ids: Token ids whose probability is zeroed (with the
            remaining mass renormalised) so repeating them is not penalised.

    Returns:
        A scalar tensor: the mean repeat probability over all adjacent
        pairs, or a zero that stays attached to the autograd graph when there
        are no adjacent pairs (fewer than two positions, or an empty batch).
    """
    probs = (logits.float() / temperature).softmax(dim=-1)

    if exclude_token_ids:
        # Clone first: the softmax output is written in place below.
        probs = probs.clone()
        for token_id in exclude_token_ids:
            probs[..., token_id] = 0.0
        # Renormalise; the clamp guards rows whose whole mass was excluded.
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-8)

    p_prev = probs[:, :-1, :]
    p_next = probs[:, 1:, :]
    repeat_prob = (p_prev * p_next).sum(dim=-1)

    if repeat_prob.numel() == 0:
        # The mean of an empty tensor is NaN, which would poison the total
        # loss. Return a zero that is still connected to the graph.
        return probs.sum() * 0.0
    return repeat_prob.mean()
class CenterOneSepLoss:
    """Auxiliary losses pushing for a single SEP token near the sequence centre.

    The sequence length is taken from the first position whose most probable
    token is EOS. Within that length, the SEP probability is scored on where
    it sits (``center_loss``), how much of it there is in total
    (``count_loss``) and how sharp its peak is (``peak_loss``).
    """

    def __init__(self, logits: Float[Tensor, "B L S"], temperature=1.0):
        """Store the logits and their temperature-scaled softmax.

        Args:
            logits: Tensor indexed as (batch, position, token).
            temperature: Divisor applied to the logits before the softmax.
        """
        self.logits = logits
        self.probs = (self.logits.float() / temperature).softmax(dim=-1)
        self.device = logits.device

    def get_len(self) -> Int[Tensor, "B"]:
        """Return each sequence's length, not counting EOS.

        The length is the index of the first position whose most probable
        token is EOS, or the full number of positions when no position
        predicts EOS.
        """
        pred_tokens = self.probs.argmax(dim=-1)
        is_eos = pred_tokens == SpecialToken.eos
        # argmax over the 0/1 mask yields the first EOS position (it is 0 for
        # rows without any EOS, which the torch.where below overrides).
        first_eos = is_eos.int().argmax(dim=-1)
        full_len = torch.full_like(first_eos, self.probs.size(1))
        return torch.where(is_eos.any(dim=-1), first_eos, full_len)

    def mirror_tensor(
        self, probs_len: Int[Tensor, "B"]
    ) -> tuple[Float[Tensor, "B L"], Bool[Tensor, "B L"]]:
        """Build a per-position centrality weight and a validity mask.

        Inside each sequence the weight rises linearly from 0 at both ends to
        1 at the centre. Sequences of length 1 or 2 have no interior, so all
        of their positions count as central and get weight 1.

        Args:
            probs_len: Length of each sequence, as returned by ``get_len``.

        Returns:
            ``(out, mask)`` where ``mask`` is True for positions below the
            sequence length and ``out`` is the centrality weight, set to 0
            wherever ``mask`` is False.
        """
        max_len = self.logits.size(1)
        pos = torch.arange(max_len, device=self.device)[None, :]
        lengths = probs_len[:, None]

        # Distance to the nearer end of the sequence.
        mirror = torch.minimum(pos, lengths - 1 - pos)
        peak = (lengths - 1) // 2

        # peak == 0 (length 1 or 2) would make the ratio 0/0. Treat every
        # valid position as fully central there instead of fully off-centre.
        scaled = mirror.float() / peak.clamp_min(1).float()
        out = torch.where(peak > 0, scaled, torch.ones_like(scaled))

        mask = pos < lengths
        out = out.masked_fill(~mask, 0.0)
        return out, mask

    def loss(self):
        """Compute the three SEP placement losses.

        Returns:
            ``(center_loss, count_loss, peak_loss)``, each a scalar tensor
            averaged over the batch:

            * ``center_loss``: SEP probability weighted by
              ``1 - centrality`` and summed over valid positions.
            * ``count_loss``: squared error between the summed SEP
              probability and the target count (1 for non-empty sequences,
              0 for empty ones).
            * ``peak_loss``: squared error between the largest SEP
              probability and that same target.
        """
        probs_len = self.get_len()
        sep_probs = self.probs[..., SpecialToken.sep]
        centrality, mask = self.mirror_tensor(probs_len)

        # Only positions inside the sequence contribute.
        sep_mass = sep_probs.masked_fill(~mask, 0.0)
        center_loss = (sep_mass * (1.0 - centrality)).sum(dim=-1).mean()

        # Count loss
        expected_sep_count = sep_mass.sum(dim=-1)
        target_count = (probs_len > 0).float()
        count_loss = ((expected_sep_count - target_count) ** 2).mean()

        # Peak loss
        max_sep_prob = sep_mass.max(dim=-1).values
        peak_loss = ((max_sep_prob - target_count) ** 2).mean()

        return center_loss, count_loss, peak_loss