"""Build annotation units from a DMHY prefix DAG.

The DAG producer keeps repeated suffix structure shared across many raw
prefixes. This tool turns those shared nodes into compact, traceable units for
LLM or human review without calling any remote service.
"""

from __future__ import annotations

import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, Iterator

from tools.annotate_dmhy_prefix_graph import (
    heuristic_patch,
    string_list,
)

DEFAULT_DAG = Path("datasets/AnimeName/dmhy_prefix_dag.json")
DEFAULT_OUTPUT = Path("datasets/AnimeName/dmhy_prefix_dag.annotation_units.jsonl")


@dataclass(frozen=True)
class Args:
    """Validated command-line options for this tool."""

    dag: Path  # input DAG JSON file
    output: Path  # destination JSONL file (parent dirs are created on write)
    min_reachable_terminals: int  # clamped to >= 1 by parse_args
    min_incoming_count: int  # clamped to >= 1 by parse_args
    limit: int | None  # None writes every unit; negative values write none
    example_count: int  # clamped to >= 1 by parse_args


def parse_args() -> Args:
    """Parse the command line into an :class:`Args`.

    The three count options are clamped to a minimum of 1; ``--limit`` is
    passed through unchanged.
    """
    parser = argparse.ArgumentParser(
        description="Emit DAG-aware DMHY prefix annotation units as JSONL"
    )
    parser.add_argument("--dag", type=Path, default=DEFAULT_DAG, help="Input dmhy_prefix_dag.json")
    parser.add_argument(
        "--output",
        type=Path,
        default=DEFAULT_OUTPUT,
        help="Output annotation unit JSONL",
    )
    parser.add_argument(
        "--min-reachable-terminals",
        type=int,
        default=2,
        help="Select non-root nodes reaching at least this many terminals",
    )
    parser.add_argument(
        "--min-incoming-count",
        type=int,
        default=2,
        help="Select nodes with at least this many incoming DAG edges",
    )
    parser.add_argument("--limit", type=int, default=None, help="Maximum units to write")
    parser.add_argument(
        "--example-count",
        type=int,
        default=5,
        help="Maximum examples retained per example field",
    )
    ns = parser.parse_args()
    return Args(
        dag=ns.dag,
        output=ns.output,
        min_reachable_terminals=max(1, ns.min_reachable_terminals),
        min_incoming_count=max(1, ns.min_incoming_count),
        limit=ns.limit,
        example_count=max(1, ns.example_count),
    )


def load_dag(path: Path) -> dict[str, Any]:
    """Load the DAG JSON document at *path* and check its top-level shape.

    Returns:
        The decoded JSON object; it is guaranteed to contain list-valued
        ``nodes`` and ``terminals`` keys.

    Raises:
        SystemExit: if the file is missing, unreadable, not UTF-8, not valid
            JSON, or does not have the expected top-level structure.
    """
    if not path.exists():
        raise SystemExit(f"DAG not found: {path}")
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        # e.g. the path is a directory, permission is denied, or the bytes are
        # not UTF-8 -- report it like the other schema errors, not a traceback.
        raise SystemExit(f"cannot read DAG {path}: {exc}") from exc
    try:
        dag = json.loads(text)
    except json.JSONDecodeError as exc:
        raise SystemExit(f"invalid DAG JSON in {path}: {exc}") from exc
    if not isinstance(dag, dict):
        raise SystemExit(f"invalid DAG schema in {path}: root must be an object")
    if not isinstance(dag.get("nodes"), list):
        raise SystemExit(f"invalid DAG schema in {path}: missing nodes list")
    if not isinstance(dag.get("terminals"), list):
        raise SystemExit(f"invalid DAG schema in {path}: missing terminals list")
    return dag


def node_id(node: dict[str, Any], fallback: int) -> int:
    """Return ``node["id"]`` as an int, or *fallback* when the key is absent.

    Raises:
        SystemExit: if an ``id`` is present but not int-convertible.
    """
    value = node.get("id", fallback)
    try:
        return int(value)
    except (TypeError, ValueError):
        raise SystemExit(f"invalid node id: {value!r}") from None


def terminal_id(terminal: dict[str, Any], fallback: int) -> str:
    """Return the terminal's identifier as a string.

    Prefers ``terminal_id``, then ``id``, then *fallback* (the terminal's
    position in the DAG's ``terminals`` list).
    """
    value = terminal.get("terminal_id", terminal.get("id", fallback))
    return str(value)


def int_field(row: dict[str, Any], key: str, default: int = 0) -> int:
    """Return ``row[key]`` as an int.

    Missing, falsy (``None``, ``0``, ``""``) and non-int-convertible values all
    yield *default*.
    """
    try:
        return int(row.get(key, default) or default)
    except (TypeError, ValueError):
        return default


def build_indexes(dag: dict[str, Any]) -> tuple[dict[int, dict[str, Any]], dict[int, list[dict[str, Any]]]]:
    """Index the DAG's nodes by id and group its terminals by owning node.

    Non-dict entries are skipped, and so are terminals whose ``node_id`` is
    not int-convertible. If two nodes share an id, the later one wins. Each
    indexed terminal is a shallow copy annotated with ``_terminal_id`` (see
    :func:`terminal_id`) and ``_terminal_index`` (its position in the input
    list).

    Returns:
        ``(nodes_by_id, terminals_by_node_id)``.
    """
    nodes: dict[int, dict[str, Any]] = {}
    for fallback, node in enumerate(dag["nodes"]):
        if not isinstance(node, dict):
            continue
        nodes[node_id(node, fallback)] = node

    terminals_by_node: dict[int, list[dict[str, Any]]] = {}
    for fallback, terminal in enumerate(dag["terminals"]):
        if not isinstance(terminal, dict):
            continue
        try:
            terminal_node_id = int(terminal.get("node_id"))
        except (TypeError, ValueError):
            continue
        terminal = dict(terminal)
        terminal["_terminal_id"] = terminal_id(terminal, fallback)
        terminal["_terminal_index"] = fallback
        terminals_by_node.setdefault(terminal_node_id, []).append(terminal)
    return nodes, terminals_by_node


def _child_targets(node: dict[str, Any]) -> list[int]:
    """Return the int ``target`` ids of *node*'s child edges, in edge order.

    Non-dict edges and edges whose target is not int-convertible are skipped.
    """
    targets: list[int] = []
    for edge in node.get("children") or []:
        if not isinstance(edge, dict):
            continue
        try:
            targets.append(int(edge.get("target")))
        except (TypeError, ValueError):
            continue
    return targets


def reachable_terminals(
    start_node_id: int,
    nodes: dict[int, dict[str, Any]],
    terminals_by_node: dict[int, list[dict[str, Any]]],
    memo: dict[int, list[dict[str, Any]]] | None = None,
) -> list[dict[str, Any]]:
    """Return the terminals reachable from *start_node_id*, deduplicated by id.

    The order is deterministic. A node's own terminals come first. They are
    followed by each child's reachable terminals in edge order. When an id
    repeats, the first occurrence is kept. Edge targets that are not present
    in *nodes* are treated as leaves; they contribute only the terminals
    attached to them.

    Args:
        start_node_id: node to start from.
        nodes: node index from :func:`build_indexes`.
        terminals_by_node: terminal index from :func:`build_indexes`.
        memo: optional cache of already-computed results keyed by node id.
            Pass the same dict across calls to avoid re-walking shared
            subgraphs. The cached lists are returned as-is, so callers must
            not mutate them.

    Raises:
        SystemExit: if a cycle is found while traversing.
    """
    if memo is None:
        memo = {}
    if start_node_id in memo:
        return memo[start_node_id]

    # Iterative post-order DFS, so very deep DAGs cannot hit the recursion
    # limit. Each frame is (node id, its child targets, cursor over targets).
    visiting: set[int] = set()
    stack: list[tuple[int, list[int], Iterator[int]]] = []

    def push(current_id: int) -> None:
        targets = _child_targets(nodes.get(current_id, {}))
        visiting.add(current_id)
        stack.append((current_id, targets, iter(targets)))

    push(start_node_id)
    while stack:
        current_id, targets, pending = stack[-1]
        for target in pending:
            if target in memo:
                continue
            if target in visiting:
                raise SystemExit(f"cycle detected while traversing DAG at node {target}")
            push(target)
            break  # descend first; this frame resumes from its cursor later
        else:
            # Every child is memoised: own terminals, then children's in edge order.
            stack.pop()
            visiting.remove(current_id)
            found = list(terminals_by_node.get(current_id, []))
            for target in targets:
                found.extend(memo[target])
            memo[current_id] = dedupe_terminals(found)
    return memo[start_node_id]


def dedupe_terminals(terminals: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
    """Keep the first terminal per id, preserving order.

    The id is taken from ``_terminal_id``, falling back to ``terminal_id``.
    Terminals whose id is empty or missing are dropped.
    """
    seen: set[str] = set()
    result: list[dict[str, Any]] = []
    for terminal in terminals:
        tid = str(terminal.get("_terminal_id") or terminal.get("terminal_id") or "")
        if not tid or tid in seen:
            continue
        seen.add(tid)
        result.append(terminal)
    return result


def limited_unique(values: Iterable[str], limit: int) -> list[str]:
    """Return up to *limit* distinct values in first-seen order.

    Empty and whitespace-only strings are ignored.
    """
    seen: set[str] = set()
    result: list[str] = []
    for value in values:
        if not value or not value.strip() or value in seen:
            continue
        seen.add(value)
        result.append(value)
        if len(result) >= limit:
            break
    return result


def edge_labels(node: dict[str, Any], limit: int) -> list[str]:
    """Return up to *limit* distinct, non-blank child-edge labels of *node*, stringified."""
    labels = []
    for edge in node.get("children") or []:
        if isinstance(edge, dict) and edge.get("label") is not None:
            labels.append(str(edge["label"]))
    return limited_unique(labels, limit)


def aggregate_examples(terminals: list[dict[str, Any]], key: str, limit: int) -> list[str]:
    """Collect up to *limit* distinct example strings for *key* across *terminals*.

    ``"prefix"`` is read as a single scalar per terminal. Every other key is
    expanded through ``string_list`` from ``tools.annotate_dmhy_prefix_graph``.
    """
    values: list[str] = []
    for terminal in terminals:
        if key == "prefix":
            value = terminal.get("prefix")
            if value is not None:
                values.append(str(value))
        else:
            values.extend(string_list(terminal.get(key)))
    return limited_unique(values, limit)


def aggregate_needs_review(terminals: list[dict[str, Any]]) -> bool:
    """Return True if any terminal needs LLM review.

    An explicit ``annotations["needs_llm_review"]`` value wins when it is
    present and not None. Otherwise the result of ``heuristic_patch`` is
    consulted. Evaluation stops at the first terminal that needs review.
    """
    for index, terminal in enumerate(terminals):
        annotations = terminal.get("annotations")
        if isinstance(annotations, dict) and annotations.get("needs_llm_review") is not None:
            if bool(annotations["needs_llm_review"]):
                return True
            continue
        # NOTE(review): `index` is the position within this unit's list, while
        # build_indexes records the DAG-wide position as `_terminal_index`;
        # confirm which one heuristic_patch expects.
        if heuristic_patch(terminal, index)["needs_llm_review"]:
            return True
    return False


def annotation_template() -> dict[str, Any]:
    """Return a fresh, empty annotation payload for reviewers to fill in."""
    return {
        "episode_title_suffixes": [],
        "media_suffixes": [],
        "title_candidates": [],
        "notes": None,
    }


def recommended_prompt(kind: str, terminal_count: int) -> str:
    """Return the reviewer instruction text for a unit of the given *kind*."""
    if kind == "shared_suffix":
        return (
            "Review the shared DAG suffix examples and mark episode-title text, media metadata, "
            f"and possible title candidates for {terminal_count} linked terminals."
        )
    return "Review this terminal cluster and mark episode-title text, media metadata, and title candidates."


def make_unit(
    node: dict[str, Any],
    node_id_value: int,
    terminals: list[dict[str, Any]],
    kind: str,
    example_count: int,
    *,
    use_node_weight: bool = True,
) -> dict[str, Any]:
    """Build one JSON-serialisable annotation unit for *terminals* under *node*.

    Args:
        node: DAG node record; it supplies ``incoming_count``, child-edge
            labels and (optionally) a precomputed ``reachable_weight``.
        node_id_value: id used for ``unit_id`` / ``node_id``.
        terminals: the terminals this unit covers.
        kind: ``"shared_suffix"`` or ``"terminal_cluster"``.
        example_count: cap for each example list.
        use_node_weight: when True, prefer the node's positive
            ``reachable_weight`` field. Pass False when *terminals* is only a
            subset of what the node reaches; that field would then overstate
            the unit's weight. When the field is not used, the weight is the
            sum over *terminals* of ``weight``, falling back to ``count``,
            then to 1.
    """
    terminal_ids = [str(terminal["_terminal_id"]) for terminal in terminals]
    reachable_weight = int_field(node, "reachable_weight") if use_node_weight else 0
    if reachable_weight <= 0:
        reachable_weight = sum(
            int_field(terminal, "weight", int_field(terminal, "count", 1)) for terminal in terminals
        )
    return {
        "unit_id": f"dag-node-{node_id_value}",
        "node_id": node_id_value,
        "kind": kind,
        "incoming_count": int_field(node, "incoming_count"),
        "reachable_terminals": len(terminals),
        "reachable_weight": reachable_weight,
        "terminal_ids": terminal_ids,
        "prefix_examples": aggregate_examples(terminals, "prefix", example_count),
        "value_examples": aggregate_examples(terminals, "value_examples", example_count),
        "suffix_examples": aggregate_examples(terminals, "suffix_examples", example_count),
        "common_edge_labels": edge_labels(node, example_count),
        "needs_llm_review": aggregate_needs_review(terminals),
        "recommended_prompt": recommended_prompt(kind, len(terminals)),
        "annotations": annotation_template(),
    }


def _root_id(dag: dict[str, Any]) -> int:
    """Return the DAG's ``root`` node id (0 when absent or falsy).

    Raises:
        SystemExit: if ``root`` is present but not int-convertible.
    """
    value = dag.get("root", 0)
    try:
        return int(value or 0)
    except (TypeError, ValueError):
        raise SystemExit(f"invalid DAG root: {value!r}") from None


def _sort_key(rank: int, unit: dict[str, Any], current_id: int) -> tuple[int, int, int, int, int]:
    """Order units by kind rank, then heaviest / widest / most shared first, then by node id."""
    return (
        rank,
        -unit["reachable_weight"],
        -unit["reachable_terminals"],
        -unit["incoming_count"],
        current_id,
    )


def selected_units(dag: dict[str, Any], args: Args) -> list[dict[str, Any]]:
    """Select and order the annotation units for *dag*.

    Pass 1 emits a ``shared_suffix`` unit for each node that reaches at least
    one terminal and satisfies either condition:

    * ``incoming_count >= min_incoming_count``; or
    * it is not the root and reaches ``>= min_reachable_terminals`` terminals.

    Pass 2 emits one ``terminal_cluster`` unit per terminal-owning node. It
    covers that node's terminals not already covered by a pass-1 unit.

    Shared units sort before clusters. The result is truncated to
    ``args.limit`` when given.
    """
    nodes, terminals_by_node = build_indexes(dag)
    root = _root_id(dag)
    # One memo for the whole pass: each node's reachable set is computed once
    # rather than re-walking its full subgraph from every ancestor.
    memo: dict[int, list[dict[str, Any]]] = {}
    candidates: list[tuple[tuple[int, int, int, int, int], dict[str, Any]]] = []

    for current_id, node in nodes.items():
        terminals = reachable_terminals(current_id, nodes, terminals_by_node, memo)
        if not terminals:
            continue
        incoming_count = int_field(node, "incoming_count")
        reachable_count = len(terminals)
        by_shared_incoming = incoming_count >= args.min_incoming_count
        by_reachable = current_id != root and reachable_count >= args.min_reachable_terminals
        if by_shared_incoming or by_reachable:
            unit = make_unit(node, current_id, terminals, "shared_suffix", args.example_count)
            candidates.append((_sort_key(0, unit, current_id), unit))

    covered_terminal_ids = {
        tid for _key, unit in candidates for tid in unit["terminal_ids"]
    }
    for current_id, terminals in terminals_by_node.items():
        uncovered = [terminal for terminal in terminals if terminal["_terminal_id"] not in covered_terminal_ids]
        if not uncovered:
            continue
        node = nodes.get(current_id, {"id": current_id})
        # The node's precomputed reachable_weight spans everything below it.
        # Only trust it when this cluster really is the node's whole reachable
        # set; otherwise (e.g. the root's own terminal) it would grossly
        # overstate the cluster's weight and distort the ordering.
        reachable_ids = {terminal["_terminal_id"] for terminal in memo.get(current_id, [])}
        uncovered_ids = {terminal["_terminal_id"] for terminal in uncovered}
        unit = make_unit(
            node,
            current_id,
            uncovered,
            "terminal_cluster",
            args.example_count,
            use_node_weight=reachable_ids == uncovered_ids,
        )
        candidates.append((_sort_key(1, unit, current_id), unit))

    candidates.sort(key=lambda item: item[0])
    units = [unit for _key, unit in candidates]
    if args.limit is not None:
        units = units[: max(0, args.limit)]
    return units


def write_jsonl(path: Path, rows: Iterable[dict[str, Any]]) -> int:
    """Write *rows* to *path* as compact UTF-8 JSON Lines and return the row count.

    Parent directories are created as needed; an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with path.open("w", encoding="utf-8", newline="\n") as handle:
        for row in rows:
            handle.write(json.dumps(row, ensure_ascii=False, separators=(",", ":")) + "\n")
            count += 1
    return count


def main() -> None:
    """CLI entry point: load the DAG, write the units, print a JSON summary."""
    args = parse_args()
    dag = load_dag(args.dag)
    units = selected_units(dag, args)
    count = write_jsonl(args.output, units)
    summary = {
        "dag": str(args.dag),
        "output": str(args.output),
        "annotation_units": count,
        "min_reachable_terminals": args.min_reachable_terminals,
        "min_incoming_count": args.min_incoming_count,
        "example_count": args.example_count,
        "limit": args.limit,
    }
    print(json.dumps(summary, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()