# scripts/nuisance_func_human.py
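"""Nuisance-function estimation for a B-spline based LLM detector.

Given next-token log-probabilities from a causal language model, the routines
below estimate the bias vector a (expected basis under the model minus the
basis at the observed tokens) and the covariance matrix B of a B-spline basis,
then solve for the weight vector beta_hat that maximizes the human-text bias
signal under a unit-variance constraint (optionally while staying orthogonal
to the machine-text bias direction).
"""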
from typing import Optional

import torch
from torch import nn
from tqdm import tqdm

from BSpline import BSpline


def optimal_beta(a: torch.Tensor, B: torch.Tensor, c: Optional[torch.Tensor] = None) -> torch.Tensor:
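    """Maximize a^T w subject to w^T B w = 1 (and c^T w = 0 if c is given).

    A small ridge term alpha*I is added to B before solving, so the
    unconstrained direction is u = (B + alpha*I)^{-1} a. When c is supplied,
    the component along v = (B + alpha*I)^{-1} c is removed,
        w = u - mu * v  with  mu = (c^T u) / (c^T v),
    which enforces c^T w = 0 exactly. Finally w is rescaled so w^T B w = 1.
    """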
REGULARIZATION = 0.01
    # Ridge-regularized solve: (B + alpha * I) u = a
    alpha = REGULARIZATION * torch.eye(B.shape[0], device=B.device, dtype=B.dtype)
    u = torch.linalg.solve(B + alpha, a)
if c is None:
w = u
else:
        # Solve (B + alpha * I) v = c, then project out the component of u along v
        v = torch.linalg.solve(B + alpha, c)
mu = torch.dot(c, u) / torch.dot(c, v)
w = u - mu * v
# Normalize to satisfy w^T B w = 1
norm = torch.sqrt(torch.dot(w, B @ w))
return w / norm
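

# Hedged sanity check (not part of the original script, illustrative only):
# for a random symmetric positive-definite B, the vector returned by
# optimal_beta should satisfy w^T B w ~= 1 and, when c is given, c^T w ~= 0.
def _sanity_check_optimal_beta(dim: int = 4, seed: int = 0) -> None:
    gen = torch.Generator().manual_seed(seed)
    M = torch.randn(dim, dim, generator=gen)
    B = M @ M.t() + dim * torch.eye(dim)  # symmetric positive definite
    a = torch.randn(dim, generator=gen)
    c = torch.randn(dim, generator=gen)
    w = optimal_beta(a, B, c)
    assert torch.allclose(w @ B @ w, torch.tensor(1.0), atol=1e-4)  # unit B-norm
    assert torch.abs(torch.dot(c, w)) < 1e-4  # orthogonal to c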


def get_bias_vector_mc(token_list, model, bspline: BSpline, args, compute_cov: bool = False, sample_size: int = 5000):
    """Monte Carlo version of `get_bias_vector`.

    Instead of summing the basis over the full vocabulary, the expected basis
    at each position is estimated from `sample_size` tokens drawn from the
    model's next-token distribution.

    Returns the bias vector (n_bases,) and, if `compute_cov` is True, also the
    covariance matrix (n_bases, n_bases), each averaged over the texts.
    """
model.eval()
device = args.device
n_bases = bspline.n_bases
if bspline.add_intercept:
n_bases += 1
bias_list = []
cov_list = []
for tokens in tqdm(token_list):
input_ids = tokens.input_ids[0]
with torch.no_grad():
logits = model(**tokens).logits[0, :-1] # (seq_len, vocab_size)
seq_len, vocab_size = logits.shape
probs = torch.softmax(logits, dim=-1) # (seq_len, vocab)
logp = torch.log_softmax(logits, dim=-1) # (seq_len, vocab)
labels = input_ids[1:] # (seq_len,)
log_ll = logp[torch.arange(seq_len), labels] # (seq_len,)
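        # Basis evaluated at the observed next-token log-likelihoods,
        # clamped below to the start of the spline's support.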
w_j = bspline.predict(log_ll.clamp_min(bspline.start).reshape(-1)).to(device)
w_j = w_j.reshape(seq_len, n_bases)
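        # Monte Carlo estimate of E[phi(log p)] at each position: draw
        # sample_size token ids from the model's next-token distribution
        # and average the basis over the draws.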
sampled_indices = torch.multinomial(probs, num_samples=sample_size, replacement=True)
sampled_logp = logp.gather(1, sampled_indices)
flat_sampled_logp = sampled_logp.clamp_min(bspline.start).reshape(-1)
flat_basis_samples = bspline.predict(flat_sampled_logp).to(device)
basis_samples = flat_basis_samples.reshape(seq_len, sample_size, n_bases)
mean_ref = basis_samples.mean(dim=1) # (seq_len, n_bases)
        bias_sample = (mean_ref - w_j).sum(dim=0)  # (n_bases,); expected minus observed, matching get_bias_vector
bias_list.append(bias_sample)
if compute_cov:
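            # Per-position sample covariance of the basis, accumulated over
            # the sequence; the intercept column (if present) is constant,
            # so its row and column are zeroed explicitly.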
cov_sample = torch.zeros(n_bases, n_bases, device=device)
for t in range(seq_len):
phi_samples = basis_samples[t]
sample_mean = phi_samples.mean(dim=0)
centered_samples = phi_samples - sample_mean
cov_t = (centered_samples.t() @ centered_samples) / (sample_size - 1)
if bspline.add_intercept:
cov_t[0, :] = 0.0
cov_t[:, 0] = 0.0
cov_sample += cov_t
cov_list.append(cov_sample)
bias_vector = torch.stack(bias_list, dim=0).mean(dim=0) # (n_bases,)
if compute_cov:
cov_matrix = torch.stack(cov_list, dim=0).mean(dim=0) # (n_bases, n_bases)
return bias_vector, cov_matrix
return bias_vector


def get_bias_vector(token_list, model, bspline: BSpline, args, compute_cov: bool = False, speedup_rate: int = 1):
    """For each tokenized text in `token_list`, compute the bias vector
    (the expected basis under the model minus the basis at the observed next
    tokens, summed over positions) and, if requested, the covariance matrix
    of the basis under the model's next-token distribution.

    A `speedup_rate` greater than 1 restricts the expectation to the top
    vocab_size // speedup_rate tokens per position, with renormalized
    probabilities.

    Returns bias (n_bases,) and optionally cov (n_bases, n_bases), each
    averaged over the texts.
    """
model.eval()
device = args.device
n_bases = bspline.n_bases
if bspline.add_intercept:
n_bases += 1
bias_list = []
cov_list = []
for tokens in tqdm(token_list):
input_ids = tokens.input_ids[0]
with torch.no_grad():
logits = model(**tokens).logits[0, :-1] # (seq_len, vocab_size)
seq_len, vocab_size = logits.shape
probs = torch.softmax(logits, dim=-1) # (seq_len, vocab)
        # Speedup: keep only the top vocab_size // speedup_rate probabilities
        # per position and renormalize them over the reduced support.
        vocab_size = vocab_size // speedup_rate
        probs, _ = torch.topk(probs, k=vocab_size, dim=-1)
        probs = probs / probs.sum(dim=-1, keepdim=True)
logp = torch.log_softmax(logits, dim=-1) # (seq_len, vocab)
# Basis at the actual next-token labels
labels = input_ids[1:] # (seq_len,)
# Index basis per position
        log_ll = logp[torch.arange(seq_len), labels]  # (seq_len,)
w_j = bspline.predict(log_ll.clamp_min(bspline.start).reshape(-1)).to(device)
w_j = w_j.reshape(seq_len, n_bases)
# Expected basis at the next-token
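        # Softmax is monotone, so topk over logp selects the same top
        # vocab_size tokens (in the same order) as the earlier topk over
        # probs; the rows of probs and logp therefore stay aligned.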
logp, _ = torch.topk(logp, k=vocab_size, dim=-1)
# Compute basis for all log-probabilities
flat_logp = logp.clamp_min(bspline.start).reshape(-1)
flat_basis = bspline.predict(flat_logp).to(device) # (seq_len*vocab, n_bases)
basis = flat_basis.reshape(seq_len, vocab_size, n_bases) # (seq_len, vocab, n_bases)
# Expected basis per position: E[basis] = sum_k p_{ik} * basis_{ik}
mean_ref = (probs.unsqueeze(-1) * basis).sum(dim=1) # (seq_len, n_bases)
        # Bias contribution of this text: expected minus observed basis, summed over positions
        bias_sample = (mean_ref - w_j).sum(dim=0)  # (n_bases,)
bias_list.append(bias_sample)
if compute_cov:
###### Naive version ######
cov_sample = torch.zeros(n_bases, n_bases, device=device)
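            # Accumulate cov_t = E[phi phi^T] - E[phi] E[phi]^T under the
            # (truncated, renormalized) next-token distribution at position t.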
for t in range(seq_len):
p_t = probs[t] # (vocab,)
phi_t = basis[t] # (vocab, n_bases)
# E_b[phi_t]: (n_bases,)
Ex_t = (p_t.unsqueeze(1) * phi_t).sum(dim=0)
# E_b[phi_t phi_t^T]: (n_bases, n_bases)
ExxT_t = phi_t.t() @ (p_t.unsqueeze(1) * phi_t)
cov_t = ExxT_t - Ex_t.unsqueeze(1) @ Ex_t.unsqueeze(0)
cov_sample += cov_t
            ###### Faster version (equivalent, vectorized) ######
            # wb = probs.unsqueeze(-1) * basis  # (T, V, K)
            # # E[phi phi^T] per position: basis[t].T @ wb[t] -> (K, K),
            # # computed for all t with a single bmm and summed over t:
            # ExxT = torch.sum(torch.bmm(basis.transpose(1, 2), wb), dim=0)  # (K, K)
            # Ex_t = wb.sum(dim=1)  # (T, K)
            # sum_outer = torch.einsum('tk,tl->kl', Ex_t, Ex_t)  # (K, K)
            # cov_sample = ExxT - sum_outer
cov_list.append(cov_sample)
bias_vector = torch.stack(bias_list, dim=0).mean(dim=0) # (n_bases,)
if compute_cov:
cov_matrix = torch.stack(cov_list, dim=0).mean(dim=0) # (n_bases, n_bases)
return bias_vector, cov_matrix
return bias_vector


class BSplineTheory(nn.Module):
def __init__(self, bspline_args, machine_text: bool = False):
super().__init__()
self.bspline = BSpline(**bspline_args)
self.machine_text = machine_text
self.beta_hat = None
def fit(self, human_token_list, machine_token_list, model, args):
device = args.device
print("Learning w function...")
print("Fetching bias and covariance for human texts...")
bias_a, cov_B = get_bias_vector(human_token_list, model, self.bspline, args, compute_cov=True)
print("Computing beta_hat...")
# print("bias_a:", torch.round(bias_a, decimals=2))
# print("cov_B:", torch.round(cov_B, decimals=2))
if self.machine_text:
print("Fetching bias for machine-generated texts...")
bias_c = get_bias_vector(machine_token_list, model, self.bspline, args, compute_cov=False)
# print("bias_c:", torch.round(bias_c, decimals=2))
else:
bias_c = None
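        # beta_hat maximizes the human bias signal a^T beta under the unit
        # variance constraint beta^T B beta = 1, optionally orthogonal to the
        # machine bias direction c.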
self.beta_hat = optimal_beta(
bias_a.to(device), cov_B.to(device),
bias_c.to(device) if bias_c is not None else None
)
print("beta_hat:", torch.round(self.beta_hat, decimals=3))
def forward(self, x: torch.Tensor) -> torch.Tensor:
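        """Evaluate the learned w(x) = phi(x)^T beta_hat elementwise.

        `fit` must be called first so that `beta_hat` is available.
        """
        if self.beta_hat is None:
            raise RuntimeError("BSplineTheory.fit() must be called before forward().")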
device = x.device
flat = x.clamp_min(self.bspline.start).reshape(-1)
basis = self.bspline.predict(flat).to(device) # (flat_len, n_bases)
w_flat = basis @ self.beta_hat.to(device) # (flat_len,)
return w_flat.reshape(x.shape)
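

if __name__ == "__main__":
    # Illustrative only: run the optimal_beta sanity check. Fitting
    # BSplineTheory additionally requires a language model, tokenized texts,
    # and BSpline constructor arguments (defined in the local BSpline module,
    # not shown here); the intended flow is roughly:
    #   theory = BSplineTheory(bspline_args)  # bspline_args: dict for BSpline(**...)
    #   theory.fit(human_token_list, machine_token_list, model, args)
    #   scores = theory(log_likelihoods)      # elementwise w(log p)
    _sanity_check_optimal_beta()
    print("optimal_beta sanity check passed.")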