| |
| """ |
| Run AdaDetectGPT criterion on each row of a canonical JSONL (id, label, text, split). |
| Writes JSONL with ada_score, latency_ms, optional error. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| import time |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
|
|
# Make sibling modules (detect_gpt_ada, model, nuisance_func, ...) importable
# when this file is run as a standalone script rather than as a package.
# (os.path.join with a single argument was a no-op and has been dropped.)
_SCRIPTS = os.path.dirname(os.path.abspath(__file__))
if _SCRIPTS not in sys.path:
    sys.path.insert(0, _SCRIPTS)
|
|
| from detect_gpt_ada import get_classification_stat |
| from model import load_model, load_tokenizer, model_max_length |
| from nuisance_func import BSplineTwoSample |
| from nuisance_func_human import BSplineTheory |
| from utils import load_training_data, separated_string |
|
|
|
|
def _load_jsonl(path: str):
    """Yield one parsed JSON object per non-empty line of the file at *path*."""
    with open(path, encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            # Skip blank / whitespace-only lines instead of failing to parse.
            if stripped:
                yield json.loads(stripped)
|
|
|
|
def _build_w_func(args, scoring_tokenizer, scoring_model, device, score_max):
    """Construct the witness function w(.) used by the AdaDetect statistic.

    Returns a ``(w_func, beta)`` pair. ``beta`` is ``None`` for the identity
    witness; otherwise it is the (pretrained or fitted) B-spline coefficient
    vector ``beta_hat`` as a plain Python list.
    """
    if args.w_func == "identity":
        return nn.Identity(), None

    bspline_args = args.config
    # NOTE(review): this overwrites the *string* device set in main() with a
    # torch.device object; the fit() implementations apparently accept both —
    # confirm before unifying.
    args.device = device

    def _tokenize(texts):
        # One tokenizer call per training text (batch of 1), moved to `device`.
        # Shared by the "bspline" and "bspline_theory" branches, which
        # previously duplicated this list comprehension verbatim.
        return [
            scoring_tokenizer(
                x,
                return_tensors="pt",
                padding=True,
                return_token_type_ids=False,
                truncation=True,
                max_length=score_max,
            ).to(device)
            for x in texts
        ]

    if args.w_func == "pretrained":
        w_func = BSplineTwoSample(bspline_args, device)
        # Coefficients shipped with the method; no fitting data required.
        w_func.beta_hat = torch.tensor(
            [0.0, -0.011333, -0.037667, -0.056667, -0.281667, -0.592, 0.157833, 0.727333],
            device=device,
        )
    elif args.w_func == "bspline":
        print(f"Datasets for learning BSpline: {args.train_dataset}")
        train_data = load_training_data(args.train_dataset)
        w_func = BSplineTwoSample(bspline_args, device)
        # Two-sample fit: human ("original") vs machine ("sampled") texts.
        w_func.fit(
            _tokenize(train_data["original"]),
            _tokenize(train_data["sampled"]),
            scoring_model,
            args,
        )
    elif args.w_func == "bspline_theory":
        print(f"Datasets for learning BSpline: {args.train_dataset}")
        train_data = load_training_data(args.train_dataset)
        # Theory-based fit uses human texts only (machine side is None).
        w_func = BSplineTheory(bspline_args, machine_text=False)
        w_func.fit(_tokenize(train_data["original"]), None, scoring_model, args)
    else:
        raise ValueError(f"Unknown w_func {args.w_func}")

    beta = w_func.beta_hat.detach().cpu().tolist()
    return w_func, beta
|
|
|
|
def _crit_one(
    text: str,
    scoring_tokenizer,
    scoring_model,
    sampling_tokenizer,
    sampling_model,
    encode_max,
    w_func,
    shift_value,
    burn_in_num: int,
    device,
):
    """Compute the AdaDetectGPT classification statistic for a single text.

    Runs the scoring model (and, if different, the sampling/reference model)
    under no_grad, optionally drops the first ``burn_in_num`` token positions,
    and delegates the statistic to ``get_classification_stat``.
    """
    tokenized = scoring_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False,
        truncation=True,
        max_length=encode_max,
    ).to(device)
    # Next-token labels: token at position i is predicted by logits at i-1,
    # hence input_ids shifted left and logits truncated on the right below.
    labels = tokenized.input_ids[:, 1:]
    with torch.no_grad():
        logits_score = scoring_model(**tokenized).logits[:, :-1]
        if sampling_model is scoring_model:
            # Same model object: reuse the scoring logits as the reference.
            logits_ref = logits_score
        else:
            tokenized_s = sampling_tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                return_token_type_ids=False,
                truncation=True,
                max_length=encode_max,
            ).to(device)
            # Both tokenizers must produce the same token ids, otherwise the
            # per-position comparison of the two logit streams is meaningless.
            assert torch.all(tokenized_s.input_ids[:, 1:] == labels), "Tokenizer mismatch for sampling model."
            logits_ref = sampling_model(**tokenized_s).logits[:, :-1]
        if burn_in_num > 0:
            # Discard the first burn_in_num positions from both logit streams
            # and the labels, keeping all three aligned.
            logits_ref = logits_ref[:, burn_in_num:, :]
            logits_score = logits_score[:, burn_in_num:, :]
            labels = labels[:, burn_in_num:]
        return get_classification_stat(logits_ref, logits_score, labels, w_func, shift_value)
|
|
|
|
def main():
    """CLI entry point: score every row of --input_jsonl, write --output_jsonl.

    Each output row copies the input row and adds ``ada_score`` (float or
    None), ``latency_ms``, and ``ada_error`` (None on success).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input_jsonl", required=True)
    p.add_argument("--output_jsonl", required=True)
    p.add_argument("--train_dataset", type=separated_string, default="./exp_main/data/xsum_gpt2-xl&./exp_main/data/writing_gpt2-xl")
    p.add_argument("--sampling_model_name", type=str, default="gpt2-xl")
    p.add_argument("--scoring_model_name", type=str, default="gpt2-xl")
    p.add_argument("--burn_in", type=float, default=0.0)
    p.add_argument("--w_func", type=str, default="pretrained", choices=["identity", "pretrained", "bspline", "bspline_theory"])
    p.add_argument("--config", type=json.loads, default='{"start": -32, "end": 0, "n_bases": 7, "spline_order": 2, "intercept": 1}')
    p.add_argument("--cache_dir", type=str, default="./cache")
    p.add_argument("--seed", type=int, default=2025)
    args = p.parse_args()

    # Seed BEFORE loading models and fitting the witness function: the
    # "bspline"/"bspline_theory" fits run inside _build_w_func, so seeding
    # after it (as before) made those fits non-reproducible under --seed.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_str)
    args.device = device_str
    print(f"Using device: {device_str}")

    scoring_tokenizer = load_tokenizer(args.scoring_model_name, args.cache_dir)
    scoring_model = load_model(args.scoring_model_name, device_str, args.cache_dir)
    scoring_model.eval()
    score_max = model_max_length(scoring_model)
    if args.sampling_model_name != args.scoring_model_name:
        sampling_tokenizer = load_tokenizer(args.sampling_model_name, args.cache_dir)
        sampling_model = load_model(args.sampling_model_name, device_str, args.cache_dir)
        sampling_model.eval()
        # Truncate to the smaller context window so both models see the
        # identical token sequence (required by the assert in _crit_one).
        encode_max = min(score_max, model_max_length(sampling_model))
    else:
        # Same name: share tokenizer/model so _crit_one reuses the logits.
        sampling_tokenizer = scoring_tokenizer
        sampling_model = scoring_model
        encode_max = score_max

    w_func, beta = _build_w_func(args, scoring_tokenizer, scoring_model, device, score_max)
    shift_value = torch.zeros(1, device=device)

    # burn_in < 1.0 is a per-row *fraction* of the token length (resolved in
    # the loop below); burn_in >= 1.0 is a fixed number of tokens to drop.
    if args.burn_in < 1.0:
        burn_in_template = None
    else:
        burn_in_template = int(args.burn_in)

    out_dir = os.path.dirname(os.path.abspath(args.output_jsonl))
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(args.output_jsonl, "w", encoding="utf-8") as fout:
        for row in _load_jsonl(args.input_jsonl):
            text = row.get("text", "")
            out = dict(row)
            out["ada_score"] = None
            out["latency_ms"] = None
            out["ada_error"] = None
            t0 = time.perf_counter()
            try:
                # Tokenized here only to resolve the fractional burn-in from
                # the token length; _crit_one re-tokenizes the same text.
                labels_tok = scoring_tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    return_token_type_ids=False,
                    truncation=True,
                    max_length=encode_max,
                ).to(device)
                lab = labels_tok.input_ids[:, 1:]
                if args.burn_in < 1.0:
                    burn_in_num = int(float(args.burn_in) * lab.size(-1))
                else:
                    burn_in_num = burn_in_template
                crit = _crit_one(
                    text,
                    scoring_tokenizer,
                    scoring_model,
                    sampling_tokenizer,
                    sampling_model,
                    encode_max,
                    w_func,
                    shift_value,
                    burn_in_num,
                    device,
                )
                out["ada_score"] = float(crit)
            except Exception as e:
                # Best-effort batch run: record the failure on the row and
                # continue scoring the remaining rows.
                out["ada_error"] = str(e)
            out["latency_ms"] = (time.perf_counter() - t0) * 1000.0
            fout.write(json.dumps(out, ensure_ascii=False) + "\n")

    print(f"Wrote {args.output_jsonl}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|