File size: 7,528 Bytes
485127c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | import torch
from BSpline import BSpline
from torch import Tensor
from torch import nn
from tqdm import tqdm
class BSplineTwoSample(nn.Module):
    """Witness function ``w(z) = B(z) @ beta_hat`` built on a B-spline basis ``B``.

    ``fit`` scores human and LLM texts with a language model, expands the
    per-token log-likelihoods in the B-spline basis and estimates the
    coefficient vector ``beta_hat``; ``forward`` then evaluates the witness
    element-wise on any tensor of log-likelihoods.
    """

    def __init__(self, bspline_args, device):
        """Build ``BSpline(**bspline_args)`` and move it to ``device``."""
        super().__init__()
        self.bspline = BSpline(**bspline_args)
        self.bspline = self.bspline.to(device)

    def inv_sqrt_matrix(self, M: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """Return the inverse square root ``M^{-1/2}`` of a symmetric matrix ``M``.

        Eigenvalues are clamped to at least ``eps`` before inversion, so
        (near-)singular inputs do not yield inf/NaN.  On Apple MPS tensors the
        eigendecomposition runs on the CPU in float32 and the factors are moved
        back to ``M.device``.
        """
        if M.device.type == "mps":
            # NOTE(review): presumably eigh is unsupported on MPS, hence the CPU detour.
            M_cpu = M.detach().to("cpu", dtype=torch.float32)
            w, Q = torch.linalg.eigh(M_cpu)
            w = w.to(M.device)
            Q = Q.to(M.device)
        else:
            w, Q = torch.linalg.eigh(M)  # w: (d,), Q: (d, d)
        w_inv_sqrt = w.clamp(min=eps) ** -0.5  # (d,)
        # Q @ diag(w^{-1/2}) @ Q^T; scaling Q's columns avoids building the diagonal.
        return (Q * w_inv_sqrt) @ Q.T

    def solve_beta_star(
        self,
        A: torch.Tensor,
        c: torch.Tensor,
        d: torch.Tensor,
        reg: float = 0.01,
    ) -> torch.Tensor:
        """Solve ``max_beta c^T beta / sqrt(beta^T A_r beta)`` s.t. ``d^T beta = 0``.

        ``A_r = A + reg * I`` is a ridge-regularized copy of ``A``; the returned
        optimizer is normalized so that ``beta^T A_r beta = 1``.

        Args:
            A: (k, k) symmetric positive (semi-)definite matrix.
            c: (k,) objective direction.
            d: (k,) constraint direction.
            reg: ridge added to the diagonal of ``A`` (default 0.01).

        Returns:
            (k,) tensor ``beta*``.
        """
        A_reg = A + reg * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
        v = torch.linalg.solve(A_reg, c)  # A_r^{-1} c
        w = torch.linalg.solve(A_reg, d)  # A_r^{-1} d
        # Lagrange multiplier that enforces d^T beta = 0.
        mu = torch.dot(d, v) / torch.dot(d, w)
        beta0 = v - mu * w  # A_r^{-1} (c - mu d)
        norm = torch.sqrt(torch.dot(beta0, A_reg @ beta0))
        return beta0 / norm

    @staticmethod
    def _sum_of_mean_covariances(feats_list, d, device, dtype) -> torch.Tensor:
        """Return ``sum_i Cov(F_i) / L_i`` over sequences ``F_i`` of ``L_i >= 2`` rows.

        ``Cov`` is the unbiased (``L_i - 1``) sample covariance. Sequences with
        fewer than two tokens have an undefined unbiased covariance (0/0 -> NaN)
        and are skipped so they cannot poison the sum.
        """
        total = torch.zeros((d, d), device=device, dtype=dtype)
        for feats in feats_list:
            n_tok = feats.shape[0]
            if n_tok < 2:
                continue
            centered = feats - feats.mean(dim=0, keepdim=True)
            total += ((centered.T @ centered) / (n_tok - 1)) / n_tok
        return total

    def compute_beta_hat(
        self,
        z_u_list,  # human texts: list[torch.Tensor], each of shape (1, Li)
        z_v_list,  # LLM texts: list[torch.Tensor], each of shape (1, Lj)
        constraint,
    ) -> torch.Tensor:
        """Estimate the witness coefficients from human (u) and LLM (v) token scores.

        Each sequence is expanded in the B-spline basis and averaged over its
        tokens.  ``delta`` is the difference of the summed per-sequence means
        (v minus u) and ``Sigma`` the sum of per-sequence covariances of those
        means.

        Returns:
            If ``constraint`` is falsy: ``Sigma^{-1/2} delta`` scaled to unit L2
            norm.  Otherwise ``solve_beta_star(Sigma, delta, sum of u-means)``.
        """
        device = z_u_list[0].device
        d = self.bspline.n_bases + (1 if self.bspline.add_intercept else 0)

        # 1) Flatten all tokens so the basis is evaluated in a single call.
        u_lengths = [z.shape[-1] for z in z_u_list]
        v_lengths = [z.shape[-1] for z in z_v_list]
        start = self.bspline.start
        all_u = torch.cat([z.squeeze(0).clamp_min(start) for z in z_u_list], dim=0).to(device)
        all_v = torch.cat([z.squeeze(0).clamp_min(start) for z in z_v_list], dim=0).to(device)

        # 2) B-spline features, computed in at least float32: eigh/solve reject
        # half precision and Sigma/delta must share one dtype.
        all_u_feats = self.bspline(all_u)  # (sum(u_lengths), d)
        all_v_feats = self.bspline(all_v)  # (sum(v_lengths), d)
        work_dtype = torch.promote_types(
            torch.promote_types(all_u_feats.dtype, all_v_feats.dtype), torch.float32
        )
        all_u_feats = all_u_feats.to(work_dtype)
        all_v_feats = all_v_feats.to(work_dtype)

        # 3) Back to per-sequence feature matrices.
        u_feats = torch.split(all_u_feats, u_lengths, dim=0)
        v_feats = torch.split(all_v_feats, v_lengths, dim=0)

        # 4) Per-sequence means, (n_u, d) and (n_v, d).
        u_means = torch.stack([f.mean(dim=0) for f in u_feats], dim=0)
        v_means = torch.stack([f.mean(dim=0) for f in v_feats], dim=0)
        # NOTE(review): sums (not averages) of the means, so group sizes matter
        # when n_u != n_v; Sigma below is summed consistently with that choice.
        delta = v_means.sum(dim=0) - u_means.sum(dim=0)  # (d,)

        # 5) Sum of the covariances of the per-sequence means.
        Sigma = self._sum_of_mean_covariances(
            u_feats, d, device, work_dtype
        ) + self._sum_of_mean_covariances(v_feats, d, device, work_dtype)

        # 6) Direction: constrained solve, or whitened delta Sigma^{-1/2} delta.
        if constraint:
            return self.solve_beta_star(Sigma, delta, u_means.sum(dim=0))
        beta_tilde = self.inv_sqrt_matrix(Sigma) @ delta  # (d,)
        return beta_tilde / beta_tilde.norm(p=2)

    def get_zij(self, token_list, model, args):
        """Return, per tokenized text, the log-probability of each observed next token.

        Each entry is ``log_softmax(logits[:, :-1]).gather(input_ids[:, 1:])``,
        i.e. a tensor of shape (batch, L - 1).  ``args`` is accepted for
        interface compatibility and is not used.
        """
        model.eval()
        z_list = []
        for tokenized in tqdm(token_list):
            with torch.no_grad():
                logits_score = model(**tokenized).logits[:, :-1]
                labels = tokenized.input_ids[:, 1:]
                if labels.ndim == logits_score.ndim - 1:
                    labels = labels.unsqueeze(-1)
                log_probs = torch.log_softmax(logits_score, dim=-1)
                z_list.append(log_probs.gather(dim=-1, index=labels).squeeze(-1))
        return z_list

    def fit(self, human_token_list, machine_token_list, model, args, constraint=False):
        """Score both corpora with ``model`` and store the fitted ``self.beta_hat``."""
        print("Learning witness function...")
        print("Fetch log-likelihood of human texts...")
        z_ij_u = self.get_zij(human_token_list, model, args)
        print("Fetch log-likelihood of LLM texts...")
        z_ij_v = self.get_zij(machine_token_list, model, args)
        beta_hat = self.compute_beta_hat(z_ij_u, z_ij_v, constraint)
        self.beta_hat = beta_hat
        print("beta_hat:", torch.round(beta_hat, decimals=3))

    def forward(self, input: Tensor):
        """Evaluate the witness ``B(z) @ beta_hat`` element-wise; output has ``input``'s shape.

        Inputs are clamped from below at ``self.bspline.start`` first.  Requires
        ``fit`` (or an explicit ``beta_hat``) beforehand.
        """
        input_shape = input.shape
        flat = input.clamp_min(self.bspline.start).reshape(-1)
        feats = self.bspline(flat)
        beta = self.beta_hat
        # Half-precision features (e.g. from a bf16 model) would otherwise fail
        # to matmul with a float32 beta_hat.
        common = torch.promote_types(feats.dtype, beta.dtype)
        w_value = feats.to(common) @ beta.to(common)
        return w_value.reshape(input_shape)
def get_ci_list(text_list, tokenizer, model, w_fun, args):
    """Compute one centred witness statistic ``c_i`` per text.

    For each text, ``w_fun`` is applied to the full-vocabulary log-softmax at
    every position.  ``c_i`` is the mean witness value of the observed next
    tokens minus the position-average of the witness's expectation under the
    model's own softmax distribution.  Only the first sequence of each
    tokenized batch (index ``[0]``) is kept.

    Args:
        text_list: iterable of texts passed one at a time to ``tokenizer``.
        tokenizer: called with ``return_tensors="pt"``; output moved to ``args.device``.
        model: causal LM whose output exposes ``.logits``.
        w_fun: callable mapping log-probabilities to witness values element-wise.
        args: object providing ``device``.

    Returns:
        list of 0-d tensors, one per text.
    """
    model.eval()
    c_list = []
    for original_text in tqdm(text_list):
        tokenized = tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
        labels = tokenized.input_ids[:, 1:]
        # Pure inference: keep the model forward AND the vocabulary-sized witness
        # evaluation under no_grad so no autograd history is built or retained
        # by the c_i values stored in c_list.
        with torch.no_grad():
            logits_score = model(**tokenized).logits[:, :-1]
            if labels.ndim == logits_score.ndim - 1:
                labels = labels.unsqueeze(-1)
            witness = w_fun(torch.log_softmax(logits_score, dim=-1))
            probs_ref = torch.softmax(logits_score, dim=-1)
            mean_ref = (probs_ref * witness).sum(dim=-1)  # E_p[w] per position
            z_j = witness.gather(dim=-1, index=labels).squeeze(-1)  # w at observed tokens
            ci = (z_j.mean(dim=-1) - mean_ref.mean(dim=-1))[0]
        c_list.append(ci)
    return c_list
class ShiftLearner(nn.Module):
    """Holds ``c_hat``, the average witness statistic over a reference corpus."""

    def __init__(self):
        super().__init__()

    def fit(self, data, tokenizer, model, w_func, args):
        """Compute per-text statistics for ``data['original']`` and store their mean.

        The statistics come from ``get_ci_list``; their mean is kept in
        ``self.c_hat`` and printed rounded to three decimals.
        """
        print("Learning shift...")
        per_text_stats = get_ci_list(data['original'], tokenizer, model, w_func, args)
        self.c_hat = torch.tensor(per_text_stats).mean()
        print("c_hat:", torch.round(self.c_hat, decimals=3))
|