#!/usr/bin/env python3
"""
Enforce a single contiguous TITLE span for every JSONL row.

This script is deterministic and streaming-friendly for very large datasets.
It is intended as a hard safety pass before/alongside LLM relabeling.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

from anifilebert.label_repairs import repair_jsonl_item


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the script.

    Returns:
        Namespace with ``input``, ``output``, ``manifest_output`` (empty
        string when not given) and ``progress`` (print interval in rows).
    """
    parser = argparse.ArgumentParser(description="Force contiguous TITLE spans in JSONL labels")
    parser.add_argument("--input", required=True, help="Input JSONL")
    parser.add_argument("--output", required=True, help="Output JSONL")
    parser.add_argument("--manifest-output", default="", help="Optional manifest JSON")
    parser.add_argument("--progress", type=int, default=50000, help="Progress print interval")
    return parser.parse_args()


def normalize_iob2(labels: Sequence[str]) -> List[str]:
    """Rewrite ``B-``/``I-`` prefixes so that every entity run is valid IOB2.

    The first label of a run of identical entities becomes ``B-<entity>`` and
    the following ones ``I-<entity>``; adjacent runs of the same entity are
    therefore merged into one span.  Anything that is not a string starting
    with ``B-`` or ``I-``, or that carries an empty entity name (``"B-"``),
    is emitted as ``"O"`` and terminates the current run.

    Args:
        labels: Raw label sequence; elements may be malformed.

    Returns:
        A new list of the same length containing only ``"O"``,
        ``"B-<entity>"`` and ``"I-<entity>"`` strings.
    """
    out: List[str] = []
    prev = ""
    for lb in labels:
        entity = ""
        if isinstance(lb, str) and lb.startswith(("B-", "I-")):
            entity = lb[2:]
        if not entity:
            # An empty entity used to compare equal to the initial ``prev``
            # and produced the malformed tag "I-"; treat it as outside.
            out.append("O")
            prev = ""
            continue
        prefix = "I" if prev == entity else "B"
        out.append(f"{prefix}-{entity}")
        prev = entity
    return out


def is_discontinuous_title(labels: Sequence[str]) -> bool:
    """Return True when TITLE labels appear in more than one separate run.

    A label counts as TITLE when it is a string ending in ``"TITLE"``.
    """
    seen_title = False
    seen_gap = False
    for lb in labels:
        is_title = isinstance(lb, str) and lb.endswith("TITLE")
        if is_title:
            if seen_title and seen_gap:
                # A TITLE label after a gap that followed an earlier TITLE.
                return True
            seen_title = True
        elif seen_title:
            seen_gap = True
    return False


def title_segments(labels: Sequence[str]) -> List[Tuple[int, int]]:
    """Locate every maximal run of TITLE labels.

    Returns:
        Half-open ``(start, end)`` index pairs in left-to-right order.
    """
    segs: List[Tuple[int, int]] = []
    i = 0
    n = len(labels)
    while i < n:
        if str(labels[i]).endswith("TITLE"):
            j = i + 1
            while j < n and str(labels[j]).endswith("TITLE"):
                j += 1
            segs.append((i, j))
            i = j
        else:
            i += 1
    return segs


def first_episode_or_special_index(labels: Sequence[str]) -> int:
    """Return the index of the first EPISODE or SPECIAL label.

    Returns ``len(labels)`` when no such label exists.
    """
    for idx, lb in enumerate(labels):
        if str(lb).endswith(("EPISODE", "SPECIAL")):
            return idx
    return len(labels)


def pick_primary_title_segment(
    labels: Sequence[str], segs: Sequence[Tuple[int, int]]
) -> Tuple[int, int]:
    """Choose the TITLE segment that should be kept.

    Args:
        labels: Label sequence the segments were computed from.
        segs: Half-open ``(start, end)`` TITLE segments.

    Returns:
        The earliest segment starting before the first EPISODE/SPECIAL label,
        otherwise the earliest segment overall, or ``(-1, -1)`` when ``segs``
        is empty.
    """
    if not segs:
        return (-1, -1)
    bound = first_episode_or_special_index(labels)
    before = [seg for seg in segs if seg[0] < bound]
    # Prefer the earliest title span before episode/special boundary.
    return min(before or segs, key=lambda seg: seg[0])


def enforce_contiguous_title(labels: Sequence[str]) -> List[str]:
    """Return IOB2-normalized labels containing at most one TITLE span.

    When several TITLE spans exist, the one selected by
    ``pick_primary_title_segment`` is kept and every other TITLE label is
    replaced with ``"O"``; the result is normalized again afterwards.
    """
    fixed = normalize_iob2(labels)
    segs = title_segments(fixed)
    if len(segs) <= 1:
        return fixed

    keep_start, keep_end = pick_primary_title_segment(fixed, segs)
    if keep_start < 0:
        return fixed

    out = list(fixed)
    for idx, lb in enumerate(out):
        if str(lb).endswith("TITLE") and not (keep_start <= idx < keep_end):
            out[idx] = "O"
    return normalize_iob2(out)


def _dump_row(rec: object) -> str:
    """Serialize one record as a compact JSON line terminated by a newline."""
    return json.dumps(rec, ensure_ascii=False, separators=(",", ":")) + "\n"


def main() -> None:
    """Stream the input JSONL, fix TITLE spans and write output plus manifest.

    Rows are written to a temporary file next to the output and moved into
    place once the whole input has been processed.  Rows that are not JSON
    objects, or whose ``tokens``/``labels`` are not lists of equal length,
    are counted as invalid and copied through unchanged.
    """
    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    if args.manifest_output:
        manifest_path = Path(args.manifest_output)
    else:
        manifest_path = output_path.with_suffix(".contiguous_title.manifest.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    rows = 0
    changed_rows = 0
    bad_before = 0
    bad_after = 0
    invalid_rows = 0

    tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")
    with input_path.open("r", encoding="utf-8") as src, tmp_path.open(
        "w", encoding="utf-8", newline="\n"
    ) as dst:
        for line in src:
            # Skip blank and whitespace-only lines instead of failing on them.
            line = line.strip()
            if not line:
                continue
            rows += 1
            rec = json.loads(line)

            # A row that is valid JSON but not an object has no labels to fix.
            if not isinstance(rec, dict):
                invalid_rows += 1
                dst.write(_dump_row(rec))
                continue

            tokens = rec.get("tokens", [])
            labels = rec.get("labels", [])
            if (
                not isinstance(tokens, list)
                or not isinstance(labels, list)
                or len(tokens) != len(labels)
            ):
                invalid_rows += 1
                dst.write(_dump_row(rec))
                continue

            if is_discontinuous_title(labels):
                bad_before += 1

            new_labels = enforce_contiguous_title(labels)
            out_rec: Dict = dict(rec)
            out_rec["labels"] = new_labels
            # Only the first element of the helper's result is used here; it
            # is treated as the record to write out.
            repaired, _ = repair_jsonl_item(out_rec)
            out_labels = repaired.get("labels", new_labels)
            if is_discontinuous_title(out_labels):
                bad_after += 1
            if out_labels != labels:
                changed_rows += 1
            repaired["labels"] = out_labels
            dst.write(_dump_row(repaired))

            if args.progress > 0 and rows % args.progress == 0:
                print(
                    f"rows={rows} changed={changed_rows} "
                    f"bad_before={bad_before} bad_after={bad_after} invalid={invalid_rows}"
                )

    # Move the finished file into place only after every row was written.
    tmp_path.replace(output_path)
    manifest = {
        "input": str(input_path),
        "output": str(output_path),
        "rows": rows,
        "changed_rows": changed_rows,
        "discontinuous_before": bad_before,
        "discontinuous_after": bad_after,
        "invalid_rows": invalid_rows,
    }
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(manifest, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()