File size: 6,216 Bytes
e63569d
 
 
 
 
 
 
 
 
 
 
 
8c50d16
 
 
e63569d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c50d16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""Rebuild AnimeName weak labels from each stored filename."""

from __future__ import annotations

import argparse
import json
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
from typing import Iterable

from tools.dmhy_dataset import weak_label_filename
from anifilebert.label_repairs import repair_jsonl_item
from anifilebert.tokenizer import AnimeTokenizer


def parse_args() -> argparse.Namespace:
    """Build the command-line parser for the relabel tool and parse ``sys.argv``.

    ``--input`` and ``--output`` are mandatory paths. The remaining path options
    default to ``None``. The integer knobs control vocab size, an optional row
    limit, progress-print frequency and how many example/failure rows the
    manifest keeps.
    """
    cli = argparse.ArgumentParser(description="Relabel a JSONL dataset from filename strings")

    # Mandatory paths.
    cli.add_argument("--input", required=True, help="Input JSONL containing filename fields")
    cli.add_argument("--output", required=True, help="Output relabeled regex-token JSONL")

    # Optional paths; all default to None.
    optional_paths = (
        ("--manifest-output", "Relabel manifest JSON"),
        ("--vocab-output", "Optional regex vocab JSON"),
        ("--base-vocab", "Optional regex vocab whose IDs should be preserved"),
    )
    for flag, help_text in optional_paths:
        cli.add_argument(flag, default=None, help=help_text)

    # Integer options, registered in their original order.
    int_options = (
        ("--max-vocab-size", 3000),
        ("--limit", None),
        ("--progress", 50000),
        ("--example-count", 20),
    )
    for flag, fallback in int_options:
        cli.add_argument(flag, type=int, default=fallback)

    namespace = cli.parse_args()
    return namespace


def iter_jsonl(path: Path) -> Iterable[dict]:
    """Yield one parsed JSON object per non-blank line of the JSONL file at *path*.

    Blank (whitespace-only) lines are skipped.

    The file is opened with ``utf-8-sig``, so a leading UTF-8 byte-order mark is
    tolerated. Without it, the BOM stays attached to the first line (``str.strip``
    does not remove it) and ``json.loads`` rejects that line.

    Raises:
        ValueError: if a line is not valid JSON, or parses to something other than
            a JSON object. The message carries ``path:line_no``; for invalid JSON,
            the ``JSONDecodeError`` is chained as the cause.
    """
    with path.open("r", encoding="utf-8-sig") as handle:
        for line_no, line in enumerate(handle, 1):
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}:{line_no}: invalid JSON") from exc
            # Callers treat each row as a dict (e.g. ``item.get(...)``); reject
            # arrays/scalars here with a located error instead of letting them
            # fail later with an AttributeError.
            if not isinstance(item, dict):
                raise ValueError(f"{path}:{line_no}: expected a JSON object")
            yield item


def length_stats(values: list[int]) -> dict:
    """Summarize a list of integer lengths.

    Returns a dict with ``min``, ``mean``, ``max`` and the percentiles
    ``p50``/``p90``/``p95``/``p99``. Each percentile is the element of the sorted
    data nearest to rank ``pct/100 * (n - 1)``. An empty input yields all-zero
    statistics.

    Ranks are rounded half-up. The built-in ``round`` rounds half to even, which
    made ties resolve to the lower or upper neighbour depending on index parity:
    p50 of 2 items picked the lower value, p50 of 4 items the upper one.
    """
    if not values:
        return {"min": 0, "mean": 0, "p50": 0, "p90": 0, "p95": 0, "p99": 0, "max": 0}
    ordered = sorted(values)
    last = len(ordered) - 1

    def percentile(pct: float) -> int:
        # Multiply before dividing so integer percentiles give exact ranks.
        rank = pct * last / 100
        # Half-up rounding; int() truncates toward zero. Clamp to valid indices.
        index = min(last, max(0, int(rank + 0.5)))
        return ordered[index]

    return {
        "min": ordered[0],
        "mean": mean(values),
        "p50": percentile(50),
        "p90": percentile(90),
        "p95": percentile(95),
        "p99": percentile(99),
        "max": ordered[-1],
    }


def main() -> None:
    """Relabel every row of ``--input`` from its ``filename`` field and write the results.

    For each JSONL row, ``weak_label_filename`` regenerates ``tokens``/``labels``
    from the stored filename. The ``tokenizer_variant``, ``source_token_count`` and
    ``char_token_count`` keys are dropped (presumably stale metadata from a previous
    tokenization; confirm). ``repair_jsonl_item`` is then applied to the relabeled
    record. Rows with no filename, or whose filename cannot be weak-labelled, are
    counted as failures and skipped.

    After the pass:
    - a tokenizer vocab is built from all written token lists, optionally seeded
      from ``--base-vocab``;
    - the vocab is written to ``--vocab-output`` if given;
    - a manifest with counts, label/length statistics and example rows is written
      and echoed to stdout (without the examples/failures lists).
    """
    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    manifest_path = Path(args.manifest_output) if args.manifest_output else output_path.with_suffix(".manifest.json")
    vocab_path = Path(args.vocab_output) if args.vocab_output else None

    # Load the base vocab up front, so a bad --base-vocab path or malformed JSON
    # fails immediately instead of after the (potentially long) relabel pass.
    base_vocab = None
    if args.base_vocab:
        with Path(args.base_vocab).open("r", encoding="utf-8") as handle:
            base_vocab = json.load(handle)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    if vocab_path:
        vocab_path.parent.mkdir(parents=True, exist_ok=True)

    tokenizer = AnimeTokenizer()
    rows_in = 0
    rows_written = 0
    rows_failed = 0
    rows_repaired_after_relabel = 0
    label_counter: Counter[str] = Counter()
    failure_counter: Counter[str] = Counter()
    token_lists: list[list[str]] = []  # kept in memory for build_vocab below
    lengths: list[int] = []
    examples: list[dict] = []
    failures: list[dict] = []

    # A non-positive --limit means "write nothing". The limit check inside the loop
    # only runs after a row has been written, so without this guard --limit 0
    # would still emit one row.
    nothing_to_write = args.limit is not None and args.limit <= 0
    items: Iterable[dict] = () if nothing_to_write else iter_jsonl(input_path)

    with output_path.open("w", encoding="utf-8", newline="\n") as out:
        for item in items:
            rows_in += 1
            filename = item.get("filename")
            if not filename:
                rows_failed += 1
                failure_counter["missing_filename"] += 1
                continue
            sample = weak_label_filename(str(filename), tokenizer)
            if sample is None:
                rows_failed += 1
                failure_counter["weak_label_failed"] += 1
                if len(failures) < args.example_count:
                    failures.append({"file_id": item.get("file_id"), "filename": filename})
                continue
            record = dict(item)
            record.pop("tokenizer_variant", None)
            record.pop("source_token_count", None)
            record.pop("char_token_count", None)
            record["tokens"] = sample["tokens"]
            record["labels"] = sample["labels"]

            # Keep the repaired record only when repair_jsonl_item reports changes.
            repaired, repairs = repair_jsonl_item(record)
            if repairs:
                rows_repaired_after_relabel += 1
                record = repaired

            out.write(json.dumps(record, ensure_ascii=False, separators=(",", ":")) + "\n")
            rows_written += 1
            label_counter.update(record["labels"])
            token_lists.append(record["tokens"])
            lengths.append(len(record["tokens"]))
            if len(examples) < args.example_count:
                examples.append(record)

            # The limit counts written rows, not input rows.
            if args.limit is not None and rows_written >= args.limit:
                break
            if args.progress and rows_written % args.progress == 0:
                print(f"relabeled {rows_written:,} rows; failed={rows_failed:,}")

    tokenizer.build_vocab(token_lists, max_size=args.max_vocab_size, base_vocab=base_vocab)
    if vocab_path:
        vocab_path.write_text(json.dumps(tokenizer.get_vocab(), ensure_ascii=False, indent=2) + "\n", encoding="utf-8")

    manifest = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "input": str(input_path),
        "output": str(output_path),
        "vocab_output": str(vocab_path) if vocab_path else None,
        "row_count": rows_written,
        "input_rows": rows_in,
        "failed_rows": rows_failed,
        "repaired_after_relabel_rows": rows_repaired_after_relabel,
        "failure_counts": dict(failure_counter),
        "label_counts": dict(label_counter),
        "token_length": length_stats(lengths),
        "vocab_size": tokenizer.vocab_size,
        "examples": examples,
        "failures": failures,
    }
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
    # Console summary omits the bulky example/failure row lists.
    print(json.dumps({k: v for k, v in manifest.items() if k not in {"examples", "failures"}}, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()