AniFileBERT / tools /enforce_contiguous_title.py
ModerRAS's picture
Add robust LLM relabel pipeline and enforce contiguous title
fed9d99
raw
history blame
5.85 kB
#!/usr/bin/env python3
"""
Enforce a single contiguous TITLE span for every JSONL row.
This script is deterministic and streaming-friendly for very large datasets.
It is intended as a hard safety pass before/alongside LLM relabeling.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict, List, Sequence, Tuple
from anifilebert.label_repairs import repair_jsonl_item
def parse_args() -> argparse.Namespace:
    """Parse the command line for the contiguous-TITLE pass.

    Returns:
        Namespace with ``input`` and ``output`` (both required paths),
        ``manifest_output`` (empty string when omitted) and ``progress``
        (row interval between progress lines; non-positive disables them).
    """
    parser = argparse.ArgumentParser(
        description="Force contiguous TITLE spans in JSONL labels",
    )
    # The two mandatory path options share the same shape.
    for flag, help_text in (("--input", "Input JSONL"), ("--output", "Output JSONL")):
        parser.add_argument(flag, required=True, help=help_text)
    parser.add_argument("--manifest-output", default="", help="Optional manifest JSON")
    parser.add_argument("--progress", type=int, default=50000, help="Progress print interval")
    return parser.parse_args()
def normalize_iob2(labels: Sequence[str]) -> List[str]:
    """Rewrite *labels* as a well-formed IOB2 sequence.

    Every ``B-X`` / ``I-X`` label is re-prefixed from context: it becomes
    ``I-X`` when the previous token has the same entity type ``X`` and
    ``B-X`` otherwise. As a result an orphan ``I-X`` turns into ``B-X``, and
    adjacent same-type spans (``B-X B-X``) are merged into a single span.

    Everything else becomes ``O``. That includes ``O`` itself, non-string
    values, labels without a ``B-``/``I-`` prefix, and prefix-only labels
    such as ``"B-"`` whose entity type is empty.

    Args:
        labels: Per-token labels, possibly malformed.

    Returns:
        A new list of the same length.
    """
    out: List[str] = []
    # Entity type of the previous token; None after an ``O``. Using None
    # (not "") keeps it distinct from any real entity string, so an empty
    # entity can never be mistaken for a continuation.
    prev = None
    for lb in labels:
        entity = lb[2:] if isinstance(lb, str) and lb.startswith(("B-", "I-")) else ""
        if not entity:
            out.append("O")
            prev = None
            continue
        out.append(f"{'I' if entity == prev else 'B'}-{entity}")
        prev = entity
    return out
def is_discontinuous_title(labels: Sequence[str]) -> bool:
    """Report whether the TITLE-tagged tokens in *labels* are split by a gap.

    A token counts as TITLE when its label is a string ending in ``"TITLE"``.
    The TITLE tokens are contiguous exactly when the distance from the first
    to the last one equals their count. Returns False when there are no
    TITLE tokens at all.
    """
    positions = [
        pos for pos, lb in enumerate(labels)
        if isinstance(lb, str) and lb.endswith("TITLE")
    ]
    if not positions:
        return False
    span_width = positions[-1] - positions[0] + 1
    return span_width != len(positions)
def title_segments(labels: Sequence[str]) -> List[Tuple[int, int]]:
    """Return every maximal run of TITLE-tagged tokens as a half-open span.

    A token counts as TITLE when ``str(label)`` ends in ``"TITLE"``. The spans
    are ``(start, end)`` pairs with *end* exclusive, listed left to right.
    """
    spans: List[Tuple[int, int]] = []
    run_start = None  # start index of the run currently open, if any
    for pos, lb in enumerate(labels):
        if str(lb).endswith("TITLE"):
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            spans.append((run_start, pos))
            run_start = None
    # A run that reaches the final token is closed at len(labels).
    if run_start is not None:
        spans.append((run_start, len(labels)))
    return spans
def first_episode_or_special_index(labels: Sequence[str]) -> int:
    """Return the index of the first EPISODE or SPECIAL label.

    A label matches when ``str(label)`` ends in ``"EPISODE"`` or ``"SPECIAL"``.
    Returns ``len(labels)`` when no label matches.
    """
    matches = (
        pos for pos, lb in enumerate(labels)
        if str(lb).endswith(("EPISODE", "SPECIAL"))
    )
    return next(matches, len(labels))
def pick_primary_title_segment(labels: Sequence[str], segs: Sequence[Tuple[int, int]]) -> Tuple[int, int]:
    """Choose which TITLE span to keep when a row has several.

    Args:
        labels: The row's labels. Accepted for interface compatibility; the
            current selection rule does not need them.
        segs: Candidate ``(start, end)`` spans, in any order.

    Returns:
        The span with the smallest start index, or ``(-1, -1)`` when *segs*
        is empty. Ties on the start index resolve to the first such span in
        *segs*.
    """
    if not segs:
        return (-1, -1)
    # The earlier "prefer spans before the first EPISODE/SPECIAL label" filter
    # could never change the result. Whenever any span starts before that
    # boundary, the globally earliest span does too. The filter, and the extra
    # scan over *labels* it needed, were therefore dropped.
    # NOTE(review): if a boundary-aware choice was intended (e.g. the longest
    # span before the first episode), it needs a different selection key.
    return min(segs, key=lambda seg: seg[0])
def enforce_contiguous_title(labels: Sequence[str]) -> List[str]:
    """Return IOB2-normalized *labels* containing at most one TITLE span.

    Rows with zero or one TITLE span are only normalized. When several spans
    exist, the span selected by ``pick_primary_title_segment`` is kept, every
    other TITLE-tagged position becomes ``O``, and the result is normalized
    again so B-/I- prefixes stay consistent.
    """
    normalized = normalize_iob2(labels)
    spans = title_segments(normalized)
    if len(spans) < 2:
        return normalized
    keep_from, keep_to = pick_primary_title_segment(normalized, spans)
    if keep_from < 0:
        # No span was selected; leave the normalized labels as they are.
        return normalized
    pruned = [
        "O" if (str(lb).endswith("TITLE") and not (keep_from <= pos < keep_to)) else lb
        for pos, lb in enumerate(normalized)
    ]
    return normalize_iob2(pruned)
def _dump_jsonl(rec: object) -> str:
    """Serialize *rec* as one compact JSONL line, trailing newline included."""
    return json.dumps(rec, ensure_ascii=False, separators=(",", ":")) + "\n"


def _row_labels(rec: object) -> List | None:
    """Return ``rec["labels"]`` if *rec* is a usable row, else ``None``.

    A row is usable when it is a JSON object whose ``tokens`` and ``labels``
    (each defaulting to ``[]`` when absent) are lists of equal length.
    """
    if not isinstance(rec, dict):
        return None
    tokens = rec.get("tokens", [])
    labels = rec.get("labels", [])
    if not isinstance(tokens, list) or not isinstance(labels, list) or len(tokens) != len(labels):
        return None
    return labels


def main() -> None:
    """Stream --input, enforce one contiguous TITLE span per row, write --output.

    Rows that cannot be processed (non-object JSON, non-list or misaligned
    ``tokens``/``labels``) are copied through unchanged and counted as
    invalid. Output goes to a sibling ``.tmp`` file that replaces the output
    path only after every row succeeded. On any error the temp file is
    removed. A JSON manifest of the counters is written afterwards and
    printed.

    Raises:
        ValueError: an input line is not valid JSON. The message names the
            input file and line number.
    """
    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    manifest_path = (
        Path(args.manifest_output)
        if args.manifest_output
        else output_path.with_suffix(".contiguous_title.manifest.json")
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    rows = 0
    changed_rows = 0
    bad_before = 0
    bad_after = 0
    invalid_rows = 0
    # Write to a temp file and swap it in only on success, so a crash never
    # leaves a truncated file at the output path.
    tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")
    try:
        with input_path.open("r", encoding="utf-8") as src, tmp_path.open(
            "w", encoding="utf-8", newline="\n"
        ) as dst:
            for lineno, line in enumerate(src, start=1):
                # Skip blank and whitespace-only lines; json.loads rejects them.
                if not line.strip():
                    continue
                rows += 1
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError as err:
                    # The decoder's own position is relative to this single
                    # line, so report the file line number as well.
                    raise ValueError(f"{input_path}:{lineno}: malformed JSON row: {err}") from err

                labels = _row_labels(rec)
                if labels is None:
                    invalid_rows += 1
                    out_rec = rec
                else:
                    if is_discontinuous_title(labels):
                        bad_before += 1
                    new_labels = enforce_contiguous_title(labels)
                    candidate: Dict = dict(rec)
                    candidate["labels"] = new_labels
                    out_rec, _ = repair_jsonl_item(candidate)
                    out_labels = out_rec.get("labels", new_labels)
                    # The repair pass may rewrite labels again, so re-check
                    # the invariant on what will actually be written.
                    if is_discontinuous_title(out_labels):
                        bad_after += 1
                    if out_labels != labels:
                        changed_rows += 1
                    out_rec["labels"] = out_labels
                dst.write(_dump_jsonl(out_rec))

                # Checked for every counted row, including pass-through ones.
                if args.progress > 0 and rows % args.progress == 0:
                    print(
                        f"rows={rows} changed={changed_rows} "
                        f"bad_before={bad_before} bad_after={bad_after} invalid={invalid_rows}"
                    )
    except BaseException:
        # Remove the partial temp file on errors and on Ctrl-C alike.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise
    tmp_path.replace(output_path)

    manifest = {
        "input": str(input_path),
        "output": str(output_path),
        "rows": rows,
        "changed_rows": changed_rows,
        "discontinuous_before": bad_before,
        "discontinuous_after": bad_after,
        "invalid_rows": invalid_rows,
    }
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(manifest, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()