# subnet32-llm-detector / scripts / run_ada_jsonl.py
# (Hugging Face capture: uploaded by ThaoTran7, commit 485127c "incomplete commit")
#!/usr/bin/env python3
"""
Run AdaDetectGPT criterion on each row of a canonical JSONL (id, label, text, split).
Writes JSONL with ada_score, latency_ms, optional error.
"""
from __future__ import annotations
import argparse
import json
import os
import random
import sys
import time
import numpy as np
import torch
from torch import nn
# Make sibling modules in this scripts/ directory importable regardless of CWD.
# (os.path.join() with a single argument is a no-op, so it was dropped.)
_SCRIPTS = os.path.dirname(os.path.abspath(__file__))
if _SCRIPTS not in sys.path:
    sys.path.insert(0, _SCRIPTS)

from detect_gpt_ada import get_classification_stat  # noqa: E402
from model import load_model, load_tokenizer, model_max_length  # noqa: E402
from nuisance_func import BSplineTwoSample  # noqa: E402
from nuisance_func_human import BSplineTheory  # noqa: E402
from utils import load_training_data, separated_string  # noqa: E402
def _load_jsonl(path: str):
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _tokenize_texts(texts, tokenizer, device, max_length):
    """Tokenize each text individually and move every encoding to *device*."""
    return [
        tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_length,
        ).to(device)
        for text in texts
    ]


def _build_w_func(args, scoring_tokenizer, scoring_model, device, score_max):
    """Construct the AdaDetectGPT weighting function selected by ``args.w_func``.

    Returns ``(w_func, beta)`` where ``beta`` is the fitted coefficient list,
    or ``(nn.Identity(), None)`` for the identity weighting.
    Raises ValueError on an unknown ``args.w_func``.
    """
    if args.w_func == "identity":
        return nn.Identity(), None
    bspline_args = args.config
    # NOTE(review): mutates the caller's args namespace; the fit() calls below
    # presumably read args.device — confirm against nuisance_func.
    args.device = device
    if args.w_func == "pretrained":
        w_func = BSplineTwoSample(bspline_args, device)
        # Coefficients fitted offline; length 8 matches n_bases=7 + intercept.
        w_func.beta_hat = torch.tensor(
            [0.0, -0.011333, -0.037667, -0.056667, -0.281667, -0.592, 0.157833, 0.727333],
            device=device,
        )
    elif args.w_func == "bspline":
        # Two-sample fit: needs both human ("original") and machine ("sampled") text.
        print(f"Datasets for learning BSpline: {args.train_dataset}")
        train_data = load_training_data(args.train_dataset)
        human_token_list = _tokenize_texts(train_data["original"], scoring_tokenizer, device, score_max)
        machine_token_list = _tokenize_texts(train_data["sampled"], scoring_tokenizer, device, score_max)
        w_func = BSplineTwoSample(bspline_args, device)
        w_func.fit(human_token_list, machine_token_list, scoring_model, args)
    elif args.w_func == "bspline_theory":
        # Theory-based fit: uses human text only.
        print(f"Datasets for learning BSpline: {args.train_dataset}")
        train_data = load_training_data(args.train_dataset)
        human_token_list = _tokenize_texts(train_data["original"], scoring_tokenizer, device, score_max)
        w_func = BSplineTheory(bspline_args, machine_text=False)
        w_func.fit(human_token_list, None, scoring_model, args)
    else:
        raise ValueError(f"Unknown w_func {args.w_func}")
    beta = w_func.beta_hat.detach().cpu().tolist()
    return w_func, beta
def _crit_one(
    text: str,
    scoring_tokenizer,
    scoring_model,
    sampling_tokenizer,
    sampling_model,
    encode_max,
    w_func,
    shift_value,
    burn_in_num: int,
    device,
):
    """Score one text with the AdaDetectGPT criterion.

    Tokenizes *text*, runs the scoring (and, when distinct, sampling) model,
    optionally drops the first *burn_in_num* token positions, and returns the
    value of ``get_classification_stat``.
    """
    tokenized = scoring_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False,
        truncation=True,
        max_length=encode_max,
    ).to(device)
    # Next-token labels: token at position t is predicted from logits at t-1.
    labels = tokenized.input_ids[:, 1:]
    with torch.no_grad():
        logits_score = scoring_model(**tokenized).logits[:, :-1]
        if sampling_model is scoring_model:
            # Shared model: reuse the scoring logits as the reference.
            logits_ref = logits_score
        else:
            tokenized_s = sampling_tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                return_token_type_ids=False,
                truncation=True,
                max_length=encode_max,
            ).to(device)
            # Both models must agree on the token sequence being scored.
            assert torch.all(tokenized_s.input_ids[:, 1:] == labels), "Tokenizer mismatch for sampling model."
            # BUGFIX: run the sampling-model forward under no_grad too; the
            # original placement could track gradients for this inference-only
            # pass, wasting memory.
            logits_ref = sampling_model(**tokenized_s).logits[:, :-1]
    if burn_in_num > 0:
        # Discard the first burn_in_num predictions (burn-in prefix).
        logits_ref = logits_ref[:, burn_in_num:, :]
        logits_score = logits_score[:, burn_in_num:, :]
        labels = labels[:, burn_in_num:]
    return get_classification_stat(logits_ref, logits_score, labels, w_func, shift_value)
def main():
    """CLI entry point: score every row of a canonical JSONL with AdaDetectGPT.

    Reads --input_jsonl (rows with id/label/text/split), writes --output_jsonl
    where each row gains ada_score, latency_ms, and ada_error fields.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input_jsonl", required=True)
    p.add_argument("--output_jsonl", required=True)
    p.add_argument("--train_dataset", type=separated_string, default="./exp_main/data/xsum_gpt2-xl&./exp_main/data/writing_gpt2-xl")
    p.add_argument("--sampling_model_name", type=str, default="gpt2-xl")
    p.add_argument("--scoring_model_name", type=str, default="gpt2-xl")
    p.add_argument("--burn_in", type=float, default=0.0)
    p.add_argument("--w_func", type=str, default="pretrained", choices=["identity", "pretrained", "bspline", "bspline_theory"])
    p.add_argument("--config", type=json.loads, default='{"start": -32, "end": 0, "n_bases": 7, "spline_order": 2, "intercept": 1}')
    p.add_argument("--cache_dir", type=str, default="./cache")
    p.add_argument("--seed", type=int, default=2025)
    args = p.parse_args()

    # BUGFIX: seed all RNGs *before* loading models and building the w_func;
    # previously seeding happened after _build_w_func, so any randomness in
    # the BSpline fit was not reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_str)
    args.device = device_str
    print(f"Using device: {device_str}")

    scoring_tokenizer = load_tokenizer(args.scoring_model_name, args.cache_dir)
    scoring_model = load_model(args.scoring_model_name, device_str, args.cache_dir)
    scoring_model.eval()
    score_max = model_max_length(scoring_model)
    if args.sampling_model_name != args.scoring_model_name:
        sampling_tokenizer = load_tokenizer(args.sampling_model_name, args.cache_dir)
        sampling_model = load_model(args.sampling_model_name, device_str, args.cache_dir)
        sampling_model.eval()
        # Both models see the same encoded input, so cap at the shorter context.
        encode_max = min(score_max, model_max_length(sampling_model))
    else:
        sampling_tokenizer = scoring_tokenizer
        sampling_model = scoring_model
        encode_max = score_max

    # beta: fitted coefficient list (None for identity); kept for debugging.
    w_func, beta = _build_w_func(args, scoring_tokenizer, scoring_model, device, score_max)
    shift_value = torch.zeros(1, device=device)

    # burn_in < 1.0 means "fraction of tokens per row" (resolved per row);
    # burn_in >= 1.0 means a fixed token count shared by all rows.
    if args.burn_in < 1.0:
        burn_in_template = None
    else:
        burn_in_template = int(args.burn_in)

    out_dir = os.path.dirname(os.path.abspath(args.output_jsonl))
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.output_jsonl, "w", encoding="utf-8") as fout:
        for row in _load_jsonl(args.input_jsonl):
            text = row.get("text", "")
            out = dict(row)
            out["ada_score"] = None
            out["latency_ms"] = None
            out["ada_error"] = None
            t0 = time.perf_counter()
            try:
                # Tokenize here only to size the fractional burn-in from the
                # row's own label length.
                labels_tok = scoring_tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    return_token_type_ids=False,
                    truncation=True,
                    max_length=encode_max,
                ).to(device)
                lab = labels_tok.input_ids[:, 1:]
                if args.burn_in < 1.0:
                    burn_in_num = int(float(args.burn_in) * lab.size(-1))
                else:
                    burn_in_num = burn_in_template
                crit = _crit_one(
                    text,
                    scoring_tokenizer,
                    scoring_model,
                    sampling_tokenizer,
                    sampling_model,
                    encode_max,
                    w_func,
                    shift_value,
                    burn_in_num,
                    device,
                )
                out["ada_score"] = float(crit)
            except Exception as e:  # best-effort: record the error, keep scoring rows
                out["ada_error"] = str(e)
            out["latency_ms"] = (time.perf_counter() - t0) * 1000.0
            fout.write(json.dumps(out, ensure_ascii=False) + "\n")
    print(f"Wrote {args.output_jsonl}")


if __name__ == "__main__":
    main()