"""Tree generation CLI utilities for DETree."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Iterable, Sequence, Set
from detree.utils.dataset import load_datapath, model_alias_mapping
def _str2bool(value: str) -> bool:
"""Parse common textual boolean representations used by legacy scripts."""
if isinstance(value, bool):
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "y"}:
return True
if lowered in {"false", "0", "no", "n"}:
return False
raise argparse.ArgumentTypeError(f"Boolean value expected, got: {value}")
def get_data_model(data_path: Iterable[Path], has_mix: bool = True) -> Set[str]:
    """Collect all model identifiers present in the provided dataset paths.

    Args:
        data_path: JSONL files whose records carry a ``src`` model field.
        has_mix: When False, records whose source mixes human and model text
            (any ``src`` containing "human" other than plain "human") are
            skipped.

    Returns:
        The set of canonical (alias-resolved) model names encountered.

    Side effects:
        Registers previously unseen sources in the shared
        ``model_alias_mapping`` (mapping them to themselves), prints one
        progress line per file, and finally prints the number of records
        kept.
    """
    llm_name: Set[str] = set()
    kept = 0
    for path in data_path:
        print(f"reading {path}")
        with path.open(mode="r", encoding="utf-8") as jsonl_file:
            for line in jsonl_file:
                record = json.loads(line)
                # Canonicalise the source name; unknown sources alias to themselves.
                src = model_alias_mapping.setdefault(record["src"], record["src"])
                # Optionally drop mixed human/model generations (e.g. "human_polish").
                if not has_mix and "human" in src and src != "human":
                    continue
                llm_name.add(src)
                kept += 1
    print(kept)
    return llm_name
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for generating DETree tree definitions."""
    parser = argparse.ArgumentParser(
        description="Generate DETree-compatible tree definitions from dataset files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument
    add(
        "--path",
        type=Path,
        default=Path("/opt/AI-text-Dataset"),
        help="Root directory of the dataset.",
    )
    add(
        "--dataset_name",
        type=str,
        default="all",
        help="Dataset configuration name.",
    )
    add(
        "--mode",
        type=str,
        choices=("train", "test", "extra"),
        default="train",
        help="Dataset split to consume.",
    )
    add(
        "--tree_txt",
        type=Path,
        default=Path("output/Tree_RAID_pcl.txt"),
        help="Output tree definition path.",
    )
    add(
        "--adversarial",
        type=_str2bool,
        default=True,
        help="Whether to include adversarial data splits.",
    )
    add(
        "--has_mix",
        type=_str2bool,
        default=True,
        help="Whether to keep mixed human/model generations.",
    )
    return parser
def main(args: argparse.Namespace) -> None:
    """Entry point for building DETree-compatible tree structures.

    Resolves the dataset split selected by ``args.mode``, gathers the
    sorted set of model names, and writes a tree definition file where
    every model is a child of a single root node (id ``len(models)``),
    which itself points at the sentinel parent ``-1``.
    """
    dataset_paths: Sequence[str] = load_datapath(
        args.path, args.adversarial, args.dataset_name
    )[args.mode]
    print(f"data_path: {dataset_paths}")
    llm_name = sorted(get_data_model((Path(p) for p in dataset_paths), args.has_mix))
    root = len(llm_name)
    # Assemble the whole definition first, then write it in one pass.
    rows = [f"{idx} {root} {model}\n" for idx, model in enumerate(llm_name)]
    rows.append(f"{root} -1 none\n")
    args.tree_txt.parent.mkdir(parents=True, exist_ok=True)
    with args.tree_txt.open("w", encoding="utf-8") as handle:
        handle.writelines(rows)
if __name__ == "__main__":
    # Parse CLI flags and hand them straight to the main routine.
    main(build_argument_parser().parse_args())