ModerRAS commited on
Commit
651ad49
·
1 Parent(s): 1e1bc1f

Add balanced focus dataset generator

Browse files
Files changed (1) hide show
  1. tools/build_balanced_focus_dataset.py +199 -0
tools/build_balanced_focus_dataset.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build a balanced focus set from real parser failures and nearby DMHY rows.
2
+
3
+ The goal is to repair boundary mistakes without teaching the model that every
4
+ special-like token should dominate title/season/episode context. Reported
5
+ failures are resolved back to their authoritative char BIO rows from DMHY when
6
+ possible, then mixed with repaired rows, broad boundary-pattern rows, random
7
+ context, and a small number of deterministic hard cases.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import json
14
+ import random
15
+ import re
16
+ from collections import Counter
17
+ from pathlib import Path
18
+ from typing import Iterable, Sequence
19
+
20
+ from anifilebert.label_repairs import repair_jsonl_item
21
+ from tools.build_path_focus_dataset import build_cases as build_path_cases
22
+ from tools.build_repair_focus_dataset import manual_cases as repair_manual_cases
23
+
24
+
25
+ BOUNDARY_FOCUS_RE = re.compile(
26
+ r"(?ix)"
27
+ r"(?:"
28
+ r"\b(?:NCOP|NCED|OP|ED|PV|CM|TVCM|OVA|OAD|SP|Menu)\s*[_\-.]?\s*(?:\d{0,4}|ep\.?\s*\d{1,4}|ver\.?\s*\d{1,2})\b|"
29
+ r"\b(?:Blu[-_ ]?ray\s*&\s*DVD|BD[-_ ]?BOX|Disc\.?\s*\d+|Vol\.?\s*\d+)\b|"
30
+ r"\b(?:S\d{1,2}[_\-.]?(?:OP|ED|NCOP|NCED)|NC(?:OP|ED)\d+[_\-.]?\d+)\b|"
31
+ r"\b(?:II|III|IV|V|Ⅱ|Ⅲ|Ⅳ|Ⅴ)\s+(?:OVA|OAD|CM|PV|OP|ED|Menu)\b|"
32
+ r"(?:弐|貳|贰|二|三|參|叁|参)\s*(?:ノ|の|之)\s*(?:章|期|季|部)|"
33
+ r"第\s*(?:\d+|[一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖]+)\s*[季期部章]|"
34
+ r"\b(?:Act|Part)\s+(?:II|III|IV|V)\b|"
35
+ r"\b(?:h\.?264|x\.?264|h\.?265|x\.?265|AVC[-_ ]?YUV|yuv\d+p?\d*|AAC\([^)]*\))\b"
36
+ r")"
37
+ )
38
+
39
+
40
+ def parse_args() -> argparse.Namespace:
41
+ parser = argparse.ArgumentParser(description=__doc__)
42
+ parser.add_argument("--input", required=True, help="Authoritative char JSONL dataset")
43
+ parser.add_argument("--output", required=True, help="Output focus JSONL")
44
+ parser.add_argument(
45
+ "--failure-report",
46
+ action="append",
47
+ default=[],
48
+ help="Parse/case metrics JSON with failures to resolve back to DMHY rows",
49
+ )
50
+ parser.add_argument("--context-samples", type=int, default=70000)
51
+ parser.add_argument("--max-boundary-rows", type=int, default=90000)
52
+ parser.add_argument("--repeat-failure", type=int, default=18)
53
+ parser.add_argument("--repeat-repaired", type=int, default=2)
54
+ parser.add_argument("--repeat-boundary", type=int, default=1)
55
+ parser.add_argument("--repeat-manual", type=int, default=8)
56
+ parser.add_argument("--repeat-path", type=int, default=24)
57
+ parser.add_argument("--seed", type=int, default=42)
58
+ return parser.parse_args()
59
+
60
+
61
+ def iter_jsonl(path: Path) -> Iterable[dict]:
62
+ with path.open("r", encoding="utf-8") as handle:
63
+ for line in handle:
64
+ line = line.strip()
65
+ if line:
66
+ yield json.loads(line)
67
+
68
+
69
+ def reservoir_add(rows: list[dict], item: dict, limit: int, rng: random.Random, seen_count: int) -> None:
70
+ if limit <= 0:
71
+ return
72
+ if len(rows) < limit:
73
+ rows.append(item)
74
+ return
75
+ index = rng.randrange(seen_count)
76
+ if index < limit:
77
+ rows[index] = item
78
+
79
+
80
+ def failure_filenames(report_paths: Sequence[str]) -> set[str]:
81
+ filenames: set[str] = set()
82
+ for value in report_paths:
83
+ path = Path(value)
84
+ if not path.exists():
85
+ continue
86
+ report = json.loads(path.read_text(encoding="utf-8"))
87
+ modes = report.get("modes", {})
88
+ for mode in modes.values():
89
+ if not isinstance(mode, dict):
90
+ continue
91
+ for failure in mode.get("failures", []):
92
+ filename = failure.get("filename")
93
+ if filename:
94
+ filenames.add(str(filename))
95
+ for result in mode.get("results", []):
96
+ if result.get("ok", True):
97
+ continue
98
+ filename = result.get("filename")
99
+ if filename:
100
+ filenames.add(str(filename))
101
+ return filenames
102
+
103
+
104
+ def clone_with_source(item: dict, source: str) -> dict:
105
+ cloned = dict(item)
106
+ cloned["source"] = source
107
+ return cloned
108
+
109
+
110
+ def main() -> None:
111
+ args = parse_args()
112
+ rng = random.Random(args.seed)
113
+ input_path = Path(args.input)
114
+ output_path = Path(args.output)
115
+
116
+ targets = failure_filenames(args.failure_report)
117
+ failure_rows: list[dict] = []
118
+ repaired_rows: list[dict] = []
119
+ boundary_rows: list[dict] = []
120
+ context_rows: list[dict] = []
121
+ seen_filenames: set[str] = set()
122
+ source_counts: Counter[str] = Counter()
123
+ total_rows = 0
124
+ boundary_seen = 0
125
+ context_seen = 0
126
+
127
+ for item in iter_jsonl(input_path):
128
+ total_rows += 1
129
+ filename = str(item.get("filename") or "")
130
+ if not filename:
131
+ continue
132
+
133
+ if filename in targets and filename not in seen_filenames:
134
+ failure_rows.append(clone_with_source(item, "balanced_report_failure"))
135
+ seen_filenames.add(filename)
136
+ continue
137
+
138
+ _repaired_item, repairs = repair_jsonl_item(item)
139
+ if repairs and filename not in seen_filenames:
140
+ repaired_rows.append(clone_with_source(item, "balanced_repaired_context"))
141
+ seen_filenames.add(filename)
142
+ continue
143
+
144
+ if BOUNDARY_FOCUS_RE.search(filename) and filename not in seen_filenames:
145
+ boundary_seen += 1
146
+ reservoir_add(
147
+ boundary_rows,
148
+ clone_with_source(item, "balanced_boundary_pattern"),
149
+ args.max_boundary_rows,
150
+ rng,
151
+ boundary_seen,
152
+ )
153
+ seen_filenames.add(filename)
154
+ continue
155
+
156
+ if filename in seen_filenames:
157
+ continue
158
+ context_seen += 1
159
+ reservoir_add(context_rows, clone_with_source(item, "balanced_random_context"), args.context_samples, rng, context_seen)
160
+
161
+ rows: list[dict] = []
162
+ rows.extend(failure_rows * max(1, args.repeat_failure))
163
+ rows.extend(repaired_rows * max(1, args.repeat_repaired))
164
+ rows.extend(boundary_rows * max(1, args.repeat_boundary))
165
+ rows.extend(context_rows)
166
+ for item in repair_manual_cases():
167
+ rows.extend([clone_with_source(item, "balanced_manual_repair")] * max(1, args.repeat_manual))
168
+ for item in build_path_cases("balanced_manual_path"):
169
+ rows.extend([item] * max(1, args.repeat_path))
170
+
171
+ rng.shuffle(rows)
172
+ output_path.parent.mkdir(parents=True, exist_ok=True)
173
+ with output_path.open("w", encoding="utf-8") as handle:
174
+ for item in rows:
175
+ handle.write(json.dumps(item, ensure_ascii=False, separators=(",", ":")) + "\n")
176
+ source_counts[str(item.get("source", "unknown"))] += 1
177
+
178
+ print(
179
+ json.dumps(
180
+ {
181
+ "input": str(input_path),
182
+ "output": str(output_path),
183
+ "total_rows": total_rows,
184
+ "failure_targets": len(targets),
185
+ "matched_failure_rows": len(failure_rows),
186
+ "repaired_rows": len(repaired_rows),
187
+ "boundary_rows": len(boundary_rows),
188
+ "context_rows": len(context_rows),
189
+ "written_rows": len(rows),
190
+ "source_counts": dict(source_counts),
191
+ },
192
+ ensure_ascii=False,
193
+ indent=2,
194
+ )
195
+ )
196
+
197
+
198
+ if __name__ == "__main__":
199
+ main()