| import random |
| import torch |
| from torch.utils.data import DataLoader, TensorDataset |
| from utils.utils import sample_categorical_logits |
| import numpy as np |
| from tqdm import tqdm |
| import torch.distributed as dist |
| import torch.nn.functional as F |
|
|
def to_one_hot(x_idx, num_classes=4):
    """Convert integer token indices [B, D] to float one-hot encodings [B, D, num_classes]."""
    oh = F.one_hot(x_idx.long(), num_classes=num_classes)
    return oh.float()
|
|
def rnd(model, reward_model, batch_size, scale=1, device='cuda:0'):
    r"""
    Run random-order sampling and accumulate the log Radon-Nikodym derivative (RND)
    $\log\frac{dP^*}{dP^u}$ along the trajectory, where $P^u$ is the path measure
    of the any-order autoregressive model and $P^*$ reweights a uniform reference
    by $e^{\mathrm{scale}\cdot r(X)}$.
    reward_model: r(X)

    return:
        - x: the final samples, [B, D]
        - log_rnd: the log RND along this trajectory, [B]
    """
    if hasattr(model, 'module'):
        model = model.module

    # Start fully masked; the mask token is vocab_size - 1.
    x = torch.full((batch_size, model.length), model.vocab_size - 1,
                   device=device, dtype=torch.int64)
    batch_arange = torch.arange(batch_size, device=device)
    # One random permutation of positions per row: the order in which tokens are revealed.
    jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)

    log_rnd = torch.zeros(batch_size, device=device)
    for d in range(model.length - 1, -1, -1):
        # Drop the mask-token column; the model is assumed to return normalized
        # log-probabilities over the remaining vocabulary.
        logits = model(x)[:, :, :-1]
        update = sample_categorical_logits(
            logits[batch_arange, jump_pos[:, d]])
        if torch.is_grad_enabled():
            # Clone to avoid in-place edits on a tensor needed for backprop.
            x = x.clone()
        x[batch_arange, jump_pos[:, d]] = update
        # Uniform-reference log-probability minus the model's log-probability
        # of the revealed token.
        log_rnd += -np.log(model.vocab_size - 1) - logits[batch_arange, jump_pos[:, d], update]
    # The reward enters once, evaluated on the fully unmasked sample.
    log_rnd += scale * reward_model(x)
    return x, log_rnd
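

# A minimal fine-tuning sketch, assuming an `optimizer` over the sampler's
# parameters (the optimizer and the choice of objective here are illustrative,
# not fixed by this file): `rnd` keeps the gradient path through the logits,
# so its output can feed any of the losses defined below.
def _example_rnd_finetune_step(model, reward_model, optimizer, batch_size=64, scale=1.0):
    """Illustrative single step; swap `loss_lv` for `loss_ce` or `loss_re_rf`."""
    _, log_rnd = rnd(model, reward_model, batch_size, scale=scale)
    loss = loss_lv(log_rnd)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()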
|
|
@torch.no_grad()
def sampling(model, batch_size, rounds=1, device='cuda:0'):
    """Any-order autoregressive sampling; returns [rounds * batch_size, D] samples."""
    if hasattr(model, 'module'):
        model = model.module
    batch_arange = torch.arange(batch_size, device=device)
    all_samples = []
    for _ in tqdm(range(rounds), leave=False):
        # Start fully masked and reveal positions in a random order per row.
        x = torch.full((batch_size, model.length), model.vocab_size - 1,
                       device=device, dtype=torch.int64)
        jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)

        for d in tqdm(range(model.length - 1, -1, -1), leave=False):
            # Drop the mask-token column before sampling the revealed token.
            logits = model.logits(x)[:, :, :-1]
            update = sample_categorical_logits(
                logits[batch_arange, jump_pos[:, d]])
            x[batch_arange, jump_pos[:, d]] = update
        all_samples.append(x)
    return torch.cat(all_samples)
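

# Usage sketch (illustrative only; the one-hot step assumes a downstream
# consumer, e.g. a reward model that expects one-hot inputs, which this file
# does not fix):
def _example_generate(model, batch_size=128, rounds=4, device='cuda:0'):
    """Draw [rounds * batch_size, L] samples and one-hot encode them."""
    samples = sampling(model, batch_size, rounds=rounds, device=device)
    return to_one_hot(samples, num_classes=model.vocab_size - 1)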
|
|
def loss_ce(log_rnd):
    """Cross-entropy loss KL(P^*||P^u), estimated with self-normalized importance weights over the batch."""
    weights = log_rnd.detach().softmax(dim=-1)
    return (log_rnd * weights).sum()
|
|
def loss_lv(log_rnd):
    r"""Log-variance loss $\mathrm{Var}_{P^{\bar u}}\log\frac{dP^*}{dP^u}$"""
    return log_rnd.var()
|
|
def loss_re_rf(log_rnd, const=0):
    r"""Relative entropy loss KL(P^u||P^*) with the REINFORCE trick; `const` is a baseline for variance reduction."""
    return (-log_rnd * (-log_rnd.detach() + const)).mean()
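

# A tiny self-check, illustrative only: all three objectives above consume the
# same [B] tensor of log Radon-Nikodym derivatives and return a scalar that is
# differentiable with respect to it.
def _losses_smoke_test():
    log_rnd = torch.randn(8, requires_grad=True)
    for fn in (loss_ce, loss_lv, loss_re_rf):
        loss = fn(log_rnd)
        assert loss.dim() == 0
        loss.backward()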
|
|
def loss_wdce(policy_model, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False):
    r"""
    Weighted denoising cross-entropy loss: samples $X \sim P^u$ reweighted by
    $\log\frac{dP^*}{dP^u}(X)$, self-normalized over the batch.

    log_rnd: [B]; x: [B, L] (no mask)
    num_replicates: R, number of replicates of each row in x
    weight_func: w(lambda) for each sample, 1/lambda by default
    """
    # Unwrap DDP before touching model attributes.
    if hasattr(policy_model, 'module'):
        policy_model = policy_model.module
    mask_index = policy_model.mask_index

    batch = x.repeat_interleave(num_replicates, dim=0)

    # Self-normalized importance weights; use detach() (not in-place detach_())
    # so the caller's tensor is left untouched.
    batch_weights = log_rnd.detach().softmax(dim=-1)
    if centering:
        batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)

    batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)

    # One masking level lambda per (replicated) row.
    lamda = torch.rand(batch.shape[0], device=batch.device)
    lamda_weights = weight_func(lamda).clamp(max=1e5)

    # Mask each token independently with probability lambda.
    masked_index = torch.rand(*batch.shape, device=batch.device) < lamda[..., None]
    perturbed_batch = torch.where(masked_index, mask_index, batch)

    t = lamda
    sigma_t = -torch.log1p(-(1 - eps) * t)
    attn_mask = torch.ones_like(perturbed_batch).to(policy_model.device)

    # The model is assumed to return normalized log-probabilities; gather the
    # log-probability of the ground-truth token at each masked position.
    logits = policy_model(perturbed_batch, attn_mask=attn_mask, sigma=sigma_t)
    losses = torch.zeros(*batch.shape, device=batch.device, dtype=logits.dtype)
    losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
                                        index=batch[masked_index][..., None]).squeeze(-1)
    return - (losses.sum(dim=-1) * lamda_weights * batch_weights).mean()
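

# Sketch of how `loss_wdce` composes with `rnd` (hypothetical `policy_model`
# and `optimizer` names; assumes `policy_model` exposes the
# forward(x, attn_mask=..., sigma=...) interface used above). The weights are
# detached inside `loss_wdce`, so `rnd` can run without gradients here.
def _example_wdce_step(model, policy_model, reward_model, optimizer,
                       batch_size=32, scale=1.0):
    with torch.no_grad():
        x, log_rnd = rnd(model, reward_model, batch_size, scale=scale)
    loss = loss_wdce(policy_model, log_rnd, x, num_replicates=16)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()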
|
|
def loss_dce(model, x, weight_func=lambda l: 1/l):
    r"""
    Denoising cross-entropy loss; x [B, D] are ground-truth samples.
    weight_func: w(lambda) for each sample, 1/lambda by default
    """
    # One masking level lambda per row; mask each token independently with
    # probability lambda (the mask token is vocab_size - 1).
    lamda = torch.rand(x.shape[0], device=x.device)
    lamda_weights = weight_func(lamda).clamp(max=1e5)
    masked_index = torch.rand(*x.shape, device=x.device) < lamda[..., None]
    perturbed_batch = torch.where(masked_index, model.vocab_size-1, x)
    # Gather the model's log-probability of the ground-truth token at each
    # masked position; unmasked positions contribute zero.
    logits = model(perturbed_batch)
    losses = torch.zeros(*x.shape, device=x.device, dtype=logits.dtype)
    losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
                                        index=x[masked_index][..., None]).squeeze(-1)
    return - (losses.sum(dim=-1) * lamda_weights).mean()
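

# Hypothetical pretraining loop for `loss_dce`; pairing it with a DataLoader
# over integer [B, D] tensors is an assumption suggested by the imports at the
# top of this file, not an interface this module defines.
def _example_dce_epoch(model, dataloader, optimizer, device='cuda:0'):
    for (x,) in dataloader:  # e.g. a DataLoader over a TensorDataset of token indices
        loss = loss_dce(model, x.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()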