# AniFileBERT / tools/annotate_dmhy_prefix_dag.py
# Commit 33bb11c by ModerRAS: "Add DMHY prefix graph annotation workflow"
"""Build annotation units from a DMHY prefix DAG.
The DAG producer keeps repeated suffix structure shared across many raw
prefixes. This tool turns those shared nodes into compact, traceable units for
LLM or human review without calling any remote service.
"""
from __future__ import annotations
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
from tools.annotate_dmhy_prefix_graph import (
heuristic_patch,
string_list,
)
# Default for --dag: the prefix DAG JSON read by load_dag (relative to the CWD).
DEFAULT_DAG = Path("datasets/AnimeName/dmhy_prefix_dag.json")
# Default for --output: the JSONL file the annotation units are written to.
DEFAULT_OUTPUT = Path("datasets/AnimeName/dmhy_prefix_dag.annotation_units.jsonl")
@dataclass(frozen=True)
class Args:
    """Normalized command-line options produced by ``parse_args``."""

    # Path of the input prefix-DAG JSON file.
    dag: Path
    # Path of the JSONL file the annotation units are written to.
    output: Path
    # Non-root nodes reaching at least this many terminals become units (>= 1).
    min_reachable_terminals: int
    # Nodes with at least this many incoming DAG edges become units (>= 1).
    min_incoming_count: int
    # Maximum number of units to write; None means no limit.
    limit: int | None
    # Maximum number of entries kept in each example / edge-label list (>= 1).
    example_count: int
def parse_args() -> Args:
    """Parse the command line into an ``Args`` record.

    ``--min-reachable-terminals``, ``--min-incoming-count`` and
    ``--example-count`` are clamped to at least 1. ``--limit`` is passed
    through unchanged (``None`` means "no limit").
    """
    parser = argparse.ArgumentParser(
        description="Emit DAG-aware DMHY prefix annotation units as JSONL"
    )
    option = parser.add_argument
    option("--dag", type=Path, default=DEFAULT_DAG, help="Input dmhy_prefix_dag.json")
    option("--output", type=Path, default=DEFAULT_OUTPUT, help="Output annotation unit JSONL")
    option(
        "--min-reachable-terminals",
        type=int,
        default=2,
        help="Select non-root nodes reaching at least this many terminals",
    )
    option(
        "--min-incoming-count",
        type=int,
        default=2,
        help="Select nodes with at least this many incoming DAG edges",
    )
    option("--limit", type=int, default=None, help="Maximum units to write")
    option(
        "--example-count",
        type=int,
        default=5,
        help="Maximum examples retained per example field",
    )
    parsed = parser.parse_args()

    def at_least_one(value: int) -> int:
        return max(1, value)

    return Args(
        dag=parsed.dag,
        output=parsed.output,
        min_reachable_terminals=at_least_one(parsed.min_reachable_terminals),
        min_incoming_count=at_least_one(parsed.min_incoming_count),
        limit=parsed.limit,
        example_count=at_least_one(parsed.example_count),
    )
def load_dag(path: Path) -> dict[str, Any]:
    """Read the DAG JSON file at ``path`` and check its top-level shape.

    Only the outer structure is validated: the document must be a JSON object
    whose ``nodes`` and ``terminals`` entries are lists. The individual
    records inside those lists are not inspected here.

    Raises:
        SystemExit: if the file is missing or unreadable, is not valid UTF-8,
            is not valid JSON, or lacks the expected top-level shape.
    """
    if not path.exists():
        raise SystemExit(f"DAG not found: {path}")
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        # A directory, a permission problem or non-UTF-8 bytes would otherwise
        # escape as a raw traceback; report them like every other input error.
        raise SystemExit(f"cannot read DAG {path}: {exc}") from exc
    try:
        dag = json.loads(text)
    except json.JSONDecodeError as exc:
        raise SystemExit(f"invalid DAG JSON in {path}: {exc}") from exc
    if not isinstance(dag, dict):
        raise SystemExit(f"invalid DAG schema in {path}: root must be an object")
    if not isinstance(dag.get("nodes"), list):
        raise SystemExit(f"invalid DAG schema in {path}: missing nodes list")
    if not isinstance(dag.get("terminals"), list):
        raise SystemExit(f"invalid DAG schema in {path}: missing terminals list")
    return dag
def node_id(node: dict[str, Any], fallback: int) -> int:
    """Return ``node``'s ``id`` converted with ``int()``.

    ``fallback`` is used when the record has no ``id`` key (build_indexes
    passes the node's position in the ``nodes`` list).

    Raises:
        SystemExit: if the id cannot be converted to an integer.
    """
    raw = node.get("id", fallback)
    try:
        parsed = int(raw)
    except (TypeError, ValueError):
        raise SystemExit(f"invalid node id: {raw!r}") from None
    return parsed
def terminal_id(terminal: dict[str, Any], fallback: int) -> str:
    """Return a string identifier for ``terminal``.

    ``terminal_id`` is preferred, then ``id``, then ``fallback``. A key whose
    value is ``None`` or ``""`` counts as absent; any other value (including
    falsy ones such as ``0``) is stringified.
    """
    for key in ("terminal_id", "id"):
        value = terminal.get(key)
        # A null id used to become the literal "None", so every id-less
        # terminal collapsed into one during de-duplication; an empty id made
        # dedupe_terminals drop the terminal entirely.
        if value is not None and value != "":
            return str(value)
    return str(fallback)
def int_field(row: dict[str, Any], key: str, default: int = 0) -> int:
    """Leniently read ``row[key]`` as an ``int``.

    Missing keys, falsy values (``None``, ``0``, ``""``) and anything ``int()``
    rejects all yield ``default``. Floats are truncated by ``int()``. For a
    dict ``row`` this never raises.
    """
    try:
        return int(row.get(key, default) or default)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: json.loads accepts ``Infinity`` and int(inf) raises
        # OverflowError rather than ValueError.
        return default
def build_indexes(
    dag: dict[str, Any],
) -> tuple[dict[int, dict[str, Any]], dict[int, list[dict[str, Any]]]]:
    """Index the DAG's nodes by id and its terminals by owning node id.

    Returns:
        ``(nodes, terminals_by_node)``. Non-dict entries are ignored in both
        lists, and a later node with a repeated id replaces an earlier one.
        Terminals whose ``node_id`` is not int-convertible are dropped. Each
        kept terminal is a shallow copy with ``_terminal_id`` (from
        ``terminal_id``) and ``_terminal_index`` (its position in
        ``dag["terminals"]``) added.

    Raises:
        SystemExit: via ``node_id`` when a node id is not an integer.
    """
    nodes: dict[int, dict[str, Any]] = {
        node_id(record, position): record
        for position, record in enumerate(dag["nodes"])
        if isinstance(record, dict)
    }

    terminals_by_node: dict[int, list[dict[str, Any]]] = {}
    for position, record in enumerate(dag["terminals"]):
        if not isinstance(record, dict):
            continue
        try:
            owner_id = int(record.get("node_id"))
        except (TypeError, ValueError):
            continue
        # Copy so the caller's DAG records are never mutated.
        enriched = {
            **record,
            "_terminal_id": terminal_id(record, position),
            "_terminal_index": position,
        }
        terminals_by_node.setdefault(owner_id, []).append(enriched)
    return nodes, terminals_by_node
def reachable_terminals(
    start_node_id: int,
    nodes: dict[int, dict[str, Any]],
    terminals_by_node: dict[int, list[dict[str, Any]]],
) -> list[dict[str, Any]]:
    """Collect every terminal reachable from ``start_node_id``.

    The walk is depth-first over each node's ``children`` edges. For each
    node, its own terminals come first, followed by each child's result in
    edge order, and the combined list is passed through ``dedupe_terminals``.
    Non-dict edges and edges whose ``target`` is not int-convertible are
    skipped. Ids missing from ``nodes`` have no children but may still own
    terminals. Results are memoized only for the duration of this call.

    Raises:
        SystemExit: if a cycle is found.
    """
    resolved: dict[int, list[dict[str, Any]]] = {}
    on_path: set[int] = set()

    def targets_of(node_key: int) -> Iterable[int]:
        for edge in nodes.get(node_key, {}).get("children") or []:
            if not isinstance(edge, dict):
                continue
            try:
                target = int(edge.get("target"))
            except (TypeError, ValueError):
                continue
            yield target

    def walk(node_key: int) -> list[dict[str, Any]]:
        if node_key in resolved:
            return resolved[node_key]
        if node_key in on_path:
            raise SystemExit(f"cycle detected while traversing DAG at node {node_key}")
        on_path.add(node_key)
        gathered = list(terminals_by_node.get(node_key, []))
        for child_key in targets_of(node_key):
            gathered.extend(walk(child_key))
        on_path.discard(node_key)
        resolved[node_key] = dedupe_terminals(gathered)
        return resolved[node_key]

    return walk(start_node_id)
def dedupe_terminals(terminals: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
    """Drop repeated terminals, keeping the first occurrence of each id.

    The id is ``_terminal_id`` if truthy, else ``terminal_id`` if truthy,
    stringified. Terminals whose id resolves to an empty string are dropped.
    """
    first_by_id: dict[str, dict[str, Any]] = {}
    for terminal in terminals:
        key = str(terminal.get("_terminal_id") or terminal.get("terminal_id") or "")
        if key and key not in first_by_id:
            first_by_id[key] = terminal
    # dicts preserve insertion order, so this is first-seen order.
    return list(first_by_id.values())
def limited_unique(values: Iterable[str], limit: int) -> list[str]:
    """Return at most ``limit`` distinct, non-blank strings from ``values``.

    Empty and whitespace-only strings are skipped. Duplicates (exact string
    equality) keep their first position. Iteration stops as soon as ``limit``
    items have been collected, and a ``limit`` of 0 or less yields ``[]``.
    """
    result: list[str] = []
    # Previously the limit was only checked after appending, so a
    # non-positive limit still returned one element.
    if limit <= 0:
        return result
    seen: set[str] = set()
    for value in values:
        if not value or not value.strip() or value in seen:
            continue
        seen.add(value)
        result.append(value)
        if len(result) >= limit:
            break
    return result
def edge_labels(node: dict[str, Any], limit: int) -> list[str]:
    """Return up to ``limit`` distinct, non-blank child-edge labels of ``node``.

    Non-dict edges and edges whose ``label`` is missing or ``None`` are
    skipped. Labels are stringified and kept in edge order.
    """
    candidates = [
        str(edge["label"])
        for edge in node.get("children") or []
        if isinstance(edge, dict) and edge.get("label") is not None
    ]
    return limited_unique(candidates, limit)
def aggregate_examples(terminals: list[dict[str, Any]], key: str, limit: int) -> list[str]:
    """Collect up to ``limit`` distinct, non-blank example strings for ``key``.

    For ``key == "prefix"`` each terminal contributes its scalar ``prefix``
    (skipped when ``None``, otherwise stringified). Any other key is expanded
    through the imported ``string_list`` helper. Order follows ``terminals``.
    """
    pool: list[str] = []
    if key == "prefix":
        for terminal in terminals:
            prefix = terminal.get("prefix")
            if prefix is not None:
                pool.append(str(prefix))
    else:
        for terminal in terminals:
            pool.extend(string_list(terminal.get(key)))
    return limited_unique(pool, limit)
def aggregate_needs_review(terminals: list[dict[str, Any]]) -> bool:
    """Return True if any terminal in ``terminals`` is flagged for LLM review.

    A non-``None`` ``annotations["needs_llm_review"]`` on a terminal decides
    for that terminal by truthiness. Otherwise the ``needs_llm_review`` entry
    of ``heuristic_patch(terminal, position)`` is used. The scan stops at the
    first flagged terminal.
    """
    for position, terminal in enumerate(terminals):
        annotations = terminal.get("annotations")
        explicit = (
            annotations.get("needs_llm_review") if isinstance(annotations, dict) else None
        )
        if explicit is not None:
            flagged = explicit
        else:
            # NOTE(review): ``position`` is the index within this unit's list,
            # not the ``_terminal_index`` that build_indexes records; confirm
            # which index heuristic_patch expects.
            flagged = heuristic_patch(terminal, position)["needs_llm_review"]
        if flagged:
            return True
    return False
def annotation_template() -> dict[str, Any]:
    """Return a fresh, empty annotation record for reviewers to fill in.

    New lists are created on every call, so units never share mutable state.
    """
    list_fields = ("episode_title_suffixes", "media_suffixes", "title_candidates")
    template: dict[str, Any] = {field: [] for field in list_fields}
    template["notes"] = None
    return template
def recommended_prompt(kind: str, terminal_count: int) -> str:
    """Return the reviewer instruction text for a unit of the given ``kind``.

    Only ``"shared_suffix"`` units mention ``terminal_count``. Every other
    kind gets a fixed, generic instruction.
    """
    if kind != "shared_suffix":
        return "Review this terminal cluster and mark episode-title text, media metadata, and title candidates."
    return (
        "Review the shared DAG suffix examples and mark episode-title text, media metadata, "
        f"and possible title candidates for {terminal_count} linked terminals."
    )
def make_unit(
    node: dict[str, Any],
    node_id_value: int,
    terminals: list[dict[str, Any]],
    kind: str,
    example_count: int,
) -> dict[str, Any]:
    """Assemble one annotation unit for ``node`` covering ``terminals``.

    Args:
        node: DAG node record. ``incoming_count``, ``reachable_weight`` and
            its child edges (for ``common_edge_labels``) are read from it.
        node_id_value: Integer id used for ``node_id`` and ``unit_id``.
        terminals: Terminal records carrying ``_terminal_id`` (as produced by
            ``build_indexes``) that this unit covers.
        kind: ``"shared_suffix"`` or ``"terminal_cluster"``.
        example_count: Maximum entries per example / edge-label list.

    Returns:
        A dict with ``unit_id``, ``node_id``, ``kind``, ``incoming_count``,
        ``reachable_terminals``, ``reachable_weight``, ``terminal_ids``, the
        three ``*_examples`` lists, ``common_edge_labels``,
        ``needs_llm_review``, ``recommended_prompt`` and an empty
        ``annotations`` template.
    """
    terminal_ids = [str(terminal["_terminal_id"]) for terminal in terminals]
    # The node's stored ``reachable_weight`` presumably aggregates everything
    # reachable from the node. Terminal-cluster units carry only the node's
    # leftover direct terminals, so that field would disagree with
    # ``reachable_terminals``. For them, recompute from the unit's own
    # terminals, which is the same fallback used when the field is absent.
    if kind == "terminal_cluster":
        reachable_weight = 0
    else:
        reachable_weight = int_field(node, "reachable_weight")
    if reachable_weight <= 0:
        reachable_weight = sum(
            int_field(terminal, "weight", int_field(terminal, "count", 1))
            for terminal in terminals
        )
    return {
        "unit_id": f"dag-node-{node_id_value}",
        "node_id": node_id_value,
        "kind": kind,
        "incoming_count": int_field(node, "incoming_count"),
        "reachable_terminals": len(terminals),
        "reachable_weight": reachable_weight,
        "terminal_ids": terminal_ids,
        "prefix_examples": aggregate_examples(terminals, "prefix", example_count),
        "value_examples": aggregate_examples(terminals, "value_examples", example_count),
        "suffix_examples": aggregate_examples(terminals, "suffix_examples", example_count),
        "common_edge_labels": edge_labels(node, example_count),
        "needs_llm_review": aggregate_needs_review(terminals),
        "recommended_prompt": recommended_prompt(kind, len(terminals)),
        "annotations": annotation_template(),
    }
def _dag_root_id(dag: dict[str, Any]) -> int:
    """Return the DAG's ``root`` node id; a missing or falsy value means 0.

    Raises:
        SystemExit: if the value is not int-convertible.
    """
    value = dag.get("root", 0) or 0
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        raise SystemExit(f"invalid DAG root id: {value!r}") from None


def _reachable_terminals_by_node(
    nodes: dict[int, dict[str, Any]],
    terminals_by_node: dict[int, list[dict[str, Any]]],
) -> dict[int, list[dict[str, Any]]]:
    """Map every id in ``nodes`` to the terminals reachable from it.

    This uses the same traversal, ordering and de-duplication as
    ``reachable_terminals``, but one memo is shared by all start nodes. Each
    node's children are therefore expanded once in total, not once per start
    node that can reach it. The map may also contain entries for edge targets
    missing from ``nodes``. Recursion depth grows with the longest path.

    Raises:
        SystemExit: if a cycle is found.
    """
    memo: dict[int, list[dict[str, Any]]] = {}
    visiting: set[int] = set()

    def visit(current_id: int) -> list[dict[str, Any]]:
        if current_id in memo:
            return memo[current_id]
        if current_id in visiting:
            raise SystemExit(f"cycle detected while traversing DAG at node {current_id}")
        visiting.add(current_id)
        found = list(terminals_by_node.get(current_id, []))
        for edge in nodes.get(current_id, {}).get("children") or []:
            if not isinstance(edge, dict):
                continue
            try:
                target = int(edge.get("target"))
            except (TypeError, ValueError):
                continue
            found.extend(visit(target))
        visiting.remove(current_id)
        memo[current_id] = dedupe_terminals(found)
        return memo[current_id]

    for current_id in nodes:
        visit(current_id)
    return memo


def _unit_sort_key(
    group: int, unit: dict[str, Any], current_id: int
) -> tuple[int, int, int, int, int]:
    """Sort key: group, then heavier, larger, more-shared units first, then id."""
    return (
        group,
        -unit["reachable_weight"],
        -unit["reachable_terminals"],
        -unit["incoming_count"],
        current_id,
    )


def selected_units(dag: dict[str, Any], args: Args) -> list[dict[str, Any]]:
    """Choose and order the annotation units to emit for ``dag``.

    Two kinds of units are produced:

    * ``shared_suffix``: one for each node that reaches at least one terminal
      and either has at least ``args.min_incoming_count`` incoming edges, or
      is not the root and reaches at least ``args.min_reachable_terminals``
      terminals.
    * ``terminal_cluster``: for each node id that directly owns terminals not
      covered by any shared-suffix unit, one unit holding those leftovers.

    Shared-suffix units sort before terminal clusters. Within each group,
    units are ordered by descending weight, descending terminal count,
    descending incoming count, then ascending node id. ``args.limit`` (when
    set) truncates the sorted list.

    Raises:
        SystemExit: if ``root`` is not an integer, a node id is invalid, or a
            cycle is found while traversing the DAG.
    """
    nodes, terminals_by_node = build_indexes(dag)
    root = _dag_root_id(dag)
    # One shared traversal instead of a fresh, unshared walk per node.
    reachable = _reachable_terminals_by_node(nodes, terminals_by_node)
    candidates: list[tuple[tuple[int, int, int, int, int], dict[str, Any]]] = []

    for current_id, node in nodes.items():
        terminals = reachable[current_id]
        if not terminals:
            continue
        by_shared_incoming = int_field(node, "incoming_count") >= args.min_incoming_count
        by_reachable = current_id != root and len(terminals) >= args.min_reachable_terminals
        if by_shared_incoming or by_reachable:
            unit = make_unit(node, current_id, terminals, "shared_suffix", args.example_count)
            candidates.append((_unit_sort_key(0, unit, current_id), unit))

    covered_terminal_ids = {
        covered_id for _key, unit in candidates for covered_id in unit["terminal_ids"]
    }
    for current_id, terminals in terminals_by_node.items():
        uncovered = [
            terminal
            for terminal in terminals
            if terminal["_terminal_id"] not in covered_terminal_ids
        ]
        if not uncovered:
            continue
        # Terminals may reference node ids that have no node record.
        node = nodes.get(current_id, {"id": current_id})
        unit = make_unit(node, current_id, uncovered, "terminal_cluster", args.example_count)
        candidates.append((_unit_sort_key(1, unit, current_id), unit))

    candidates.sort(key=lambda item: item[0])
    units = [unit for _key, unit in candidates]
    if args.limit is not None:
        units = units[: max(0, args.limit)]
    return units
def write_jsonl(path: Path, rows: Iterable[dict[str, Any]]) -> int:
    """Write ``rows`` to ``path`` as compact UTF-8 JSON Lines.

    Parent directories are created as needed and an existing file is
    overwritten. Each row is one line, with LF line endings and non-ASCII
    text kept verbatim.

    Returns:
        The number of rows written.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with path.open("w", encoding="utf-8", newline="\n") as sink:
        for written, row in enumerate(rows, start=1):
            line = json.dumps(row, ensure_ascii=False, separators=(",", ":"))
            sink.write(line + "\n")
    return written
def main() -> None:
    """Run the CLI and print a JSON summary of the run to stdout.

    Parses arguments, loads the DAG, selects annotation units and writes them
    as JSONL before printing the summary.
    """
    args = parse_args()
    units = selected_units(load_dag(args.dag), args)
    written = write_jsonl(args.output, units)
    summary = dict(
        dag=str(args.dag),
        output=str(args.output),
        annotation_units=written,
        min_reachable_terminals=args.min_reachable_terminals,
        min_incoming_count=args.min_incoming_count,
        example_count=args.example_count,
    )
    print(json.dumps(summary, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()