File size: 5,407 Bytes
485127c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python3
"""Fine-tune a small transformer for binary human vs AI on JSONL (id, label, text)."""

import argparse
import json
import os
import sys

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

class JsonlDataset(Dataset):
    """Map-style dataset over a JSONL file of {id, label, text} rows.

    Rows are read eagerly at construction; tokenization is deferred to
    ``__getitem__`` so the tokenizer runs once per access, padding/truncating
    each text to ``max_length``.
    """

    def __init__(self, path: str, tokenizer, max_length: int):
        # Blank lines are skipped; everything else must be valid JSON.
        with open(path, encoding="utf-8") as fh:
            self.rows = [json.loads(ln) for ln in fh if ln.strip()]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        row = self.rows[index]
        encoded = self.tokenizer(
            row.get("text", ""),
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch dim of 1; drop it so the
        # Trainer's collator can stack samples itself.
        sample = {key: tensor.squeeze(0) for key, tensor in encoded.items()}
        sample["labels"] = torch.tensor(int(row["label"]), dtype=torch.long)
        return sample


def load_jsonl_rows(path: str):
    """Read a JSONL file and return its rows as a list of parsed objects.

    Blank lines are ignored; any non-blank line must be valid JSON.
    """
    with open(path, encoding="utf-8") as fh:
        return [json.loads(ln) for ln in fh if ln.strip()]


def main():
    """CLI entry point.

    Two modes:
      * training (default): requires --train_jsonl; fine-tunes --model_name
        and saves model + tokenizer to --output_dir.
      * predict-only: --predict_jsonl + --predict_output; loads the model
        from --output_dir and writes each input row back with a ``sup_score``
        (probability of class 1).
    """
    p = argparse.ArgumentParser(
        description="Train a small transformer on JSONL, or run inference only with --predict_jsonl.",
    )
    p.add_argument("--train_jsonl", help="Training JSONL (required unless predict-only)")
    p.add_argument("--val_jsonl", help="Optional validation for Trainer eval")
    p.add_argument("--model_name", default="roberta-base")
    p.add_argument("--max_length", type=int, default=256)
    p.add_argument("--output_dir", required=True, help="Save dir after training, or load dir in predict-only")
    p.add_argument("--epochs", type=float, default=2.0)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument(
        "--predict_jsonl",
        help="Inference-only: input JSONL (no --train_jsonl needed). Requires --predict_output.",
    )
    p.add_argument("--predict_output", help="Inference-only: output JSONL with sup_score")
    p.add_argument("--device", default=None, help="cuda:0 or cpu (default: auto)")
    args = p.parse_args()

    # Validate the mode-dependent required arguments up front.
    predict_only = bool(args.predict_jsonl)
    if predict_only:
        if not args.predict_output:
            raise SystemExit("Predict-only mode requires --predict_output.")
    elif not args.train_jsonl:
        raise SystemExit("Training requires --train_jsonl, or use predict-only: --predict_jsonl and --predict_output.")

    # Seed both torch and numpy for reproducibility of training.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    if predict_only:
        _run_predict(args, device)
    else:
        _run_train(args)


def _run_predict(args, device):
    """Score args.predict_jsonl with the model saved in args.output_dir.

    Writes one JSON object per input row to args.predict_output, copying the
    row and adding ``sup_score`` = softmax probability of label 1.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
    model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
    model.to(device)
    model.eval()
    rows = load_jsonl_rows(args.predict_jsonl)
    out_path = args.predict_output
    out_d = os.path.dirname(os.path.abspath(out_path))
    if out_d:
        os.makedirs(out_d, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as fout:
        for r in rows:
            enc = tokenizer(
                r.get("text", ""),
                truncation=True,
                max_length=args.max_length,
                padding="max_length",
                return_tensors="pt",
            ).to(device)
            with torch.no_grad():
                logits = model(**enc).logits
                # Index [0, 1]: single-example batch, probability of class 1.
                prob = torch.softmax(logits, dim=-1)[0, 1].item()
            o = dict(r)
            o["sup_score"] = float(prob)
            fout.write(json.dumps(o, ensure_ascii=False) + "\n")
    print(f"Wrote {out_path}")


def _run_train(args):
    """Fine-tune args.model_name on args.train_jsonl and save to args.output_dir."""
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, num_labels=2)
    train_ds = JsonlDataset(args.train_jsonl, tokenizer, args.max_length)
    eval_ds = JsonlDataset(args.val_jsonl, tokenizer, args.max_length) if args.val_jsonl else None

    os.makedirs(args.output_dir, exist_ok=True)
    targs = TrainingArguments(
        output_dir=args.output_dir,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_steps=50,
        save_strategy="epoch",
        # Bug fix: this was hardcoded to "no", so a provided --val_jsonl was
        # loaded but never evaluated. Evaluate each epoch when a val set exists.
        evaluation_strategy="epoch" if eval_ds is not None else "no",
        seed=args.seed,
    )
    trainer = Trainer(
        model=model,
        args=targs,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Saved model to {args.output_dir}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()