| import torch |
| from BSpline import BSpline |
| from torch import Tensor |
| from torch import nn |
| from tqdm import tqdm |
|
|
class BSplineTwoSample(nn.Module):
    """Two-sample witness function expressed in a B-spline feature basis.

    ``fit`` scores two collections of tokenized texts with a language model,
    maps the per-token log-likelihoods through ``self.bspline`` and stores a
    coefficient vector ``beta_hat``.  ``forward`` then evaluates the learned
    witness ``bspline(x) @ beta_hat`` elementwise.
    """

    def __init__(self, bspline_args, device):
        """Build the feature map from ``bspline_args`` and move it to ``device``.

        Args:
            bspline_args: keyword arguments forwarded to ``BSpline(...)``.
            device: target passed to the feature map's ``.to(...)``.
        """
        super().__init__()
        self.bspline = BSpline(**bspline_args).to(device)
        # Populated by fit(); kept as a plain attribute (not a buffer) so the
        # module's state_dict layout is unchanged.
        self.beta_hat = None

    def inv_sqrt_matrix(self, M: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """Return ``M^{-1/2}`` for a symmetric matrix via eigendecomposition.

        Args:
            M: square symmetric matrix (only 2-D input is handled; ``.T`` is used).
            eps: lower clamp applied to the eigenvalues before the inverse
                square root, so tiny or negative eigenvalues do not blow up.

        Returns:
            ``Q diag(w^-1/2) Q^T`` on the same device as ``M``.
        """
        if M.device.type == "mps":
            # The decomposition is run on CPU in float32 on a detached copy and
            # the factors are moved back — presumably because eigh is not
            # available on the MPS backend (TODO confirm).
            M_cpu = M.detach().to("cpu", dtype=torch.float32)
            w, Q = torch.linalg.eigh(M_cpu)
            w = w.to(M.device)
            Q = Q.to(M.device)
        else:
            w, Q = torch.linalg.eigh(M)

        w_inv_sqrt = w.clamp(min=eps) ** -0.5
        # Scaling the columns of Q equals Q @ diag(w_inv_sqrt) without
        # materialising the dense diagonal matrix.
        return (Q * w_inv_sqrt) @ Q.T

    def solve_beta_star(
        self,
        A: torch.Tensor,
        c: torch.Tensor,
        d: torch.Tensor,
        regularization: float = 0.01,
    ) -> torch.Tensor:
        """Solve the linearly constrained ratio maximisation in closed form.

        Solves
            max_{beta}  c^T beta / sqrt(beta^T A_reg beta)
            s.t.        d^T beta = 0,
        where ``A_reg = A + regularization * I``.

        Args:
            A: (k, k) symmetric positive semi-definite matrix.
            c: (k,) objective direction.
            d: (k,) constraint normal.
            regularization: ridge term added to the diagonal of ``A``
                (default 0.01, the value previously hard-coded here).

        Returns:
            The optimiser scaled so that ``beta^T A_reg beta = 1``.  Note that
            the unit norm is with respect to the *regularised* matrix, not
            ``A`` itself.
        """
        ridge = regularization * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
        A_reg = A + ridge

        # Both right-hand sides share the same matrix, so solve them together.
        solved = torch.linalg.solve(A_reg, torch.stack((c, d), dim=1))
        v, w = solved[:, 0], solved[:, 1]

        # Lagrange multiplier that enforces d^T beta = 0.
        mu = torch.dot(d, v) / torch.dot(d, w)
        beta0 = v - mu * w

        norm = torch.sqrt(torch.dot(beta0, A_reg @ beta0))
        return beta0 / norm

    def _split_features(self, z_list, device):
        """Featurise all samples in one batched call and split back per sample."""
        lengths = [z.shape[-1] for z in z_list]
        flat = torch.cat(
            [z.squeeze(0).clamp_min(self.bspline.start) for z in z_list], dim=0
        ).to(device)
        return list(torch.split(self.bspline(flat), lengths, dim=0))

    @staticmethod
    def _mean_covariance_sum(feats, dim, device):
        """Sum over samples of (unbiased feature covariance) / (sample length)."""
        # Accumulate in the feature dtype so the result can be combined with
        # the mean-difference vector without a dtype mismatch.
        total = torch.zeros((dim, dim), device=device, dtype=feats[0].dtype)
        for sample in feats:
            n = sample.shape[0]
            centered = sample - sample.mean(dim=0, keepdim=True)
            # A one-row sample has an all-zero centered matrix; clamping the
            # denominator makes it contribute zero instead of 0/0 = NaN.
            total += ((centered.T @ centered) / max(n - 1, 1)) / n
        return total

    def compute_beta_hat(
        self,
        z_u_list,
        z_v_list,
        constraint,
    ) -> torch.Tensor:
        """Estimate the witness coefficients from two lists of token scores.

        Args:
            z_u_list: list of score tensors for the first sample; each is
                squeezed on dim 0 and its last dimension is taken as its
                length (presumably shape (1, T) — confirm against caller).
            z_v_list: same, for the second sample.
            constraint: if truthy, use ``solve_beta_star`` with the summed
                first-sample feature means as the constraint normal;
                otherwise whiten the mean difference and L2-normalise it.

        Returns:
            Coefficient vector of length ``n_bases`` (+1 with intercept).
        """
        device = z_u_list[0].device
        d = self.bspline.n_bases
        if self.bspline.add_intercept:
            d = d + 1

        u_feats = self._split_features(z_u_list, device)
        v_feats = self._split_features(z_v_list, device)

        # Sum over samples of the per-sample mean feature vector.
        u_mean_sum = torch.stack([f.mean(dim=0) for f in u_feats], dim=0).sum(dim=0)
        v_mean_sum = torch.stack([f.mean(dim=0) for f in v_feats], dim=0).sum(dim=0)
        delta = v_mean_sum - u_mean_sum

        Sigma = self._mean_covariance_sum(u_feats, d, device) + self._mean_covariance_sum(
            v_feats, d, device
        )

        if constraint:
            return self.solve_beta_star(Sigma, delta, u_mean_sum)

        beta_tilde = self.inv_sqrt_matrix(Sigma) @ delta
        return beta_tilde / beta_tilde.norm(p=2)

    def get_zij(self, token_list, model, args):
        """Return per-token log-likelihoods of each tokenized text under ``model``.

        Puts ``model`` in eval mode.  For every entry, the logits at position
        ``t`` are scored against the token at position ``t + 1``.

        Args:
            token_list: indexable collection of tokenizer outputs exposing
                ``input_ids`` and usable as ``model(**tokenized)``.
            model: callable whose result has a ``.logits`` attribute.
            args: unused; kept for interface compatibility.

        Returns:
            List with one tensor of gathered log-probabilities per entry.
        """
        model.eval()

        z_list = []
        for idx in tqdm(range(len(token_list))):
            tokenized = token_list[idx]
            labels = tokenized.input_ids[:, 1:]
            with torch.no_grad():
                logits_score = model(**tokenized).logits[:, :-1]
                # gather() needs an index with the same rank as the logits.
                if labels.ndim == logits_score.ndim - 1:
                    labels = labels.unsqueeze(-1)
                log_probs = torch.log_softmax(logits_score, dim=-1)
                z_list.append(log_probs.gather(dim=-1, index=labels).squeeze(-1))

        return z_list

    def fit(self, human_token_list, machine_token_list, model, args, constraint=False):
        """Score both corpora with ``model`` and store the fitted ``beta_hat``.

        Args:
            human_token_list: tokenized texts of the first sample.
            machine_token_list: tokenized texts of the second sample.
            model: language model passed through to ``get_zij``.
            args: passed through to ``get_zij``.
            constraint: forwarded to ``compute_beta_hat``.
        """
        print("Learning witness function...")
        print("Fetch log-likelihood of human texts...")
        z_ij_u = self.get_zij(human_token_list, model, args)
        print("Fetch log-likelihood of LLM texts...")
        z_ij_v = self.get_zij(machine_token_list, model, args)
        beta_hat = self.compute_beta_hat(z_ij_u, z_ij_v, constraint)
        self.beta_hat = beta_hat
        print("beta_hat:", torch.round(beta_hat, decimals=3))

    def forward(self, input: Tensor):
        """Evaluate the fitted witness function elementwise on ``input``.

        Values are clamped from below at ``self.bspline.start`` before being
        featurised; the output has the same shape as ``input``.

        Raises:
            RuntimeError: if ``beta_hat`` has not been set (``fit`` not called).
        """
        beta_hat = getattr(self, "beta_hat", None)
        if beta_hat is None:
            raise RuntimeError(
                "BSplineTwoSample.forward called before fit(); beta_hat is not set."
            )
        flat = input.clamp_min(self.bspline.start).reshape(-1)
        return (self.bspline(flat) @ beta_hat).reshape(input.shape)
|
|
def get_ci_list(text_list, tokenizer, model, w_fun, args):
    """Compute one centred witness statistic per text.

    Puts ``model`` in eval mode.  For each text, ``w_fun`` is applied to the
    model's full log-softmax output; the statistic is the mean (over
    positions) witness value at the observed next token minus the mean
    expected witness value under the model's own softmax distribution.
    Only the first batch element is kept.

    Args:
        text_list: indexable collection of raw texts.
        tokenizer: callable producing model inputs (called with
            ``return_tensors="pt"``); its result is moved to ``args.device``.
        model: callable whose result has a ``.logits`` attribute.
        w_fun: elementwise witness function applied to log-probabilities.
        args: object providing ``device``.

    Returns:
        List with one 0-dim tensor per text.
    """
    model.eval()

    results = []
    total = len(text_list)
    for position in tqdm(range(total)):
        encoded = tokenizer(
            text_list[position],
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(args.device)
        targets = encoded.input_ids[:, 1:]
        with torch.no_grad():
            # Logits at step t are scored against the token at step t + 1.
            shifted_logits = model(**encoded).logits[:, :-1]
            if targets.ndim == shifted_logits.ndim - 1:
                targets = targets.unsqueeze(-1)

            witness_all = w_fun(torch.log_softmax(shifted_logits, dim=-1))
            model_probs = torch.softmax(shifted_logits, dim=-1)

            expected = (model_probs * witness_all).sum(dim=-1)
            observed = witness_all.gather(dim=-1, index=targets).squeeze(-1)

            centred = observed.mean(dim=-1) - expected.mean(dim=-1)
            results.append(centred[0])
    return results
|
|
class ShiftLearner(nn.Module):
    """Estimate a scalar shift as the mean centred witness statistic."""

    def __init__(self):
        super().__init__()

    def fit(self, data, tokenizer, model, w_func, args):
        """Average ``get_ci_list`` over ``data['original']`` and store it as ``c_hat``.

        Args:
            data: mapping with an ``'original'`` entry holding the texts.
            tokenizer: passed through to ``get_ci_list``.
            model: passed through to ``get_ci_list``.
            w_func: witness function passed through to ``get_ci_list``.
            args: passed through to ``get_ci_list``.
        """
        print("Learning shift...")
        per_text = get_ci_list(data['original'], tokenizer, model, w_func, args)
        shift = torch.tensor(per_text).mean()
        self.c_hat = shift
        print("c_hat:", torch.round(shift, decimals=3))
|
|