File size: 3,428 Bytes
4d939fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Tree generation CLI utilities for DETree."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Iterable, Sequence, Set

from detree.utils.dataset import load_datapath, model_alias_mapping


def _str2bool(value: str) -> bool:
    """Convert a textual flag value into a ``bool``.

    Values that are already ``bool`` are returned unchanged. Strings are
    matched case-insensitively: ``true``/``1``/``yes``/``y`` give ``True`` and
    ``false``/``0``/``no``/``n`` give ``False``.

    Raises:
        argparse.ArgumentTypeError: If the string is not one of the accepted
            spellings.
    """

    if isinstance(value, bool):
        return value

    # Each outcome maps to the lower-cased spellings that select it.
    spellings = {
        True: {"true", "1", "yes", "y"},
        False: {"false", "0", "no", "n"},
    }
    token = value.lower()
    for outcome, accepted in spellings.items():
        if token in accepted:
            return outcome
    raise argparse.ArgumentTypeError(f"Boolean value expected, got: {value}")


def get_data_model(data_path: Iterable[Path], has_mix: bool = True) -> Set[str]:
    """Collect all model identifiers present in the provided dataset paths.

    Every path is read line by line, each non-blank line being parsed as a
    JSON record carrying a ``"src"`` key. The raw source is resolved through
    ``model_alias_mapping``; a source without an alias is registered in that
    shared mapping as its own alias, so the mapping grows as new sources are
    encountered. Progress (each path, then the number of kept records) is
    printed to stdout.

    Args:
        data_path: Files to scan, one JSON record per line.
        has_mix: When ``False``, records whose resolved source contains
            ``"human"`` without being exactly ``"human"`` are skipped.

    Returns:
        The set of resolved source names among the records that were kept.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"src"`` key.
    """

    llm_name: Set[str] = set()
    cnt = 0
    for path in data_path:
        print(f"reading {path}")
        with path.open(mode="r", encoding="utf-8") as jsonl_file:
            for line in jsonl_file:
                # Blank lines (e.g. stray trailing newlines) are not records;
                # passing them to json.loads would abort the whole scan.
                if not line.strip():
                    continue
                raw_src = json.loads(line)["src"]
                # setdefault registers an unseen source as its own alias and
                # returns the resolved name in a single lookup.
                src = model_alias_mapping.setdefault(raw_src, raw_src)
                if not has_mix and "human" in src and src != "human":
                    continue
                llm_name.add(src)
                cnt += 1
    print(cnt)
    return llm_name


def build_argument_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by the tree generation script.

    The parser exposes ``--path``, ``--dataset_name``, ``--mode``,
    ``--tree_txt``, ``--adversarial`` and ``--has_mix``; all of them are
    optional and carry defaults, which the help formatter displays.
    """

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Generate DETree-compatible tree definitions from dataset files.",
    )

    # Options are declared as a table and registered in order, so the
    # sequence below is also the order in which they are added to the parser.
    option_table = (
        (
            "--path",
            {"type": Path, "default": Path("/opt/AI-text-Dataset"), "help": "Root directory of the dataset."},
        ),
        (
            "--dataset_name",
            {"type": str, "default": "all", "help": "Dataset configuration name."},
        ),
        (
            "--mode",
            {
                "type": str,
                "choices": ("train", "test", "extra"),
                "default": "train",
                "help": "Dataset split to consume.",
            },
        ),
        (
            "--tree_txt",
            {"type": Path, "default": Path("output/Tree_RAID_pcl.txt"), "help": "Output tree definition path."},
        ),
        (
            "--adversarial",
            {"type": _str2bool, "default": True, "help": "Whether to include adversarial data splits."},
        ),
        (
            "--has_mix",
            {"type": _str2bool, "default": True, "help": "Whether to keep mixed human/model generations."},
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli


def main(args: argparse.Namespace) -> None:
    """Write a flat tree definition covering every model found in the dataset.

    The dataset paths for ``args.mode`` are obtained from ``load_datapath``,
    the model names found in them are sorted, and one line per model is
    written to ``args.tree_txt`` as ``<index> <root> <name>``, where ``<root>``
    is the number of models. A final ``<root> -1 none`` line describes the
    root node. The output directory is created if it does not exist.

    Args:
        args: Parsed CLI namespace providing ``path``, ``adversarial``,
            ``dataset_name``, ``mode``, ``has_mix`` and ``tree_txt``.
    """

    split_paths: Sequence[str] = load_datapath(args.path, args.adversarial, args.dataset_name)[args.mode]
    print(f"data_path: {split_paths}")

    ordered_names = list(get_data_model(map(Path, split_paths), args.has_mix))
    ordered_names.sort()
    root_id = len(ordered_names)

    # Leaves first (each pointing at the root), then the root itself.
    rows = [f"{idx} {root_id} {name}\n" for idx, name in enumerate(ordered_names)]
    rows.append(f"{root_id} -1 none\n")

    out_path = args.tree_txt
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as handle:
        handle.writelines(rows)


if __name__ == "__main__":
    # Script entry point: parse the command line, then generate the tree file.
    cli_args = build_argument_parser().parse_args()
    main(cli_args)