# AniFileBERT / tools/test_annotated_dmhy_workflow.py
# Commit 33bb11c (ModerRAS): "Add DMHY prefix graph annotation workflow"
"""Smoke tests for annotated DMHY graph dataset helpers."""
from __future__ import annotations
import tempfile
import json
import subprocess
import sys
import unittest
from pathlib import Path
from tools.annotate_dmhy_prefix_graph import normalize_generated_tokens
from tools.convert_annotated_dmhy_dataset import (
iter_validated_jsonl,
validate_record,
)
from tools.convert_to_char_dataset import convert_record
class AnnotatedDmhyWorkflowTests(unittest.TestCase):
    """Smoke coverage for the annotated DMHY dataset helpers and CLIs.

    The in-process tests call ``normalize_generated_tokens``,
    ``validate_record``, ``iter_validated_jsonl`` and ``convert_record``
    directly.  The ``test_cli_*`` tests launch the ``tools.*`` modules with
    ``python -m`` in a subprocess, feeding them fixtures written to a
    temporary directory and inspecting the JSONL files they produce.
    """

    # ------------------------------------------------------------------
    # Shared fixtures / helpers (names do not start with ``test`` so the
    # unittest loader does not collect them).
    # ------------------------------------------------------------------

    @staticmethod
    def _run_module(module: str, *cli_args: str) -> subprocess.CompletedProcess:
        """Run ``python -m <module> <cli_args...>`` and capture text output.

        ``check=False`` keeps a non-zero exit from raising, so the caller
        can assert on ``returncode`` and surface ``stderr`` in the failure
        message.
        """
        command = [sys.executable, "-m", module]
        command.extend(cli_args)
        return subprocess.run(
            command,
            check=False,
            capture_output=True,
            text=True,
        )

    @staticmethod
    def _read_jsonl(path: Path) -> list:
        """Return the decoded JSON object for each non-blank line of *path*."""
        decoded = []
        for raw_line in path.read_text(encoding="utf-8").splitlines():
            if raw_line.strip():
                decoded.append(json.loads(raw_line))
        return decoded

    @staticmethod
    def _write_json(path: Path, payload: dict) -> None:
        """Serialize *payload* to *path* as UTF-8 JSON without ASCII escaping."""
        path.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8")

    @staticmethod
    def _write_source_list(path: Path, values: list) -> None:
        """Write one ``{"value": ...}`` JSON object per line, newline-terminated."""
        body = "\n".join(
            json.dumps({"value": entry}, ensure_ascii=False) for entry in values
        )
        path.write_text(body + "\n", encoding="utf-8")

    # ------------------------------------------------------------------
    # In-process helper tests
    # ------------------------------------------------------------------

    def test_generated_tokens_split_punctuation_and_use_b_only_labels(self) -> None:
        """Tokens are split at punctuation and only ``O``/``B-*`` labels remain."""
        raw_tokens = ["[ANi]", " ", "Title-Name", "07"]
        raw_labels = ["B-GROUP", "O", "I-TITLE", "B-EPISODE"]

        tokens, labels = normalize_generated_tokens(raw_tokens, raw_labels)

        expected_tokens = ["[", "ANi", "]", " ", "Title", "-", "Name", "07"]
        expected_labels = [
            "O",
            "B-GROUP",
            "O",
            "O",
            "B-TITLE",
            "O",
            "B-TITLE",
            "B-EPISODE",
        ]
        self.assertEqual(tokens, expected_tokens)
        self.assertEqual(labels, expected_labels)
        self.assertTrue(all(tag == "O" or tag.startswith("B-") for tag in labels))

    def test_preserve_i_labels_keeps_i_on_non_separator_pieces(self) -> None:
        """With ``preserve_i_labels=True`` the split pieces keep ``I-`` labels."""
        tokens, labels = normalize_generated_tokens(
            ["Title-Name"],
            ["I-TITLE"],
            preserve_i_labels=True,
        )

        self.assertEqual(tokens, ["Title", "-", "Name"])
        self.assertEqual(labels, ["I-TITLE", "O", "I-TITLE"])

    def test_validation_rejects_embedded_punctuation(self) -> None:
        """A token that still contains ``-`` fails validation."""
        bad_record = {
            "filename": "Title-Name 07",
            "tokens": ["Title-Name", "07"],
            "labels": ["B-TITLE", "B-EPISODE"],
        }

        with self.assertRaisesRegex(ValueError, "contains punctuation"):
            validate_record(bad_record, Path("sample.jsonl"), 1)

    def test_validation_rejects_embedded_symbol_separator(self) -> None:
        """A token that still contains the ``×`` symbol fails validation."""
        bad_record = {
            "filename": "Title 1920×1080 07",
            "tokens": ["Title", "1920×1080", "07"],
            "labels": ["B-TITLE", "B-RESOLUTION", "B-EPISODE"],
        }

        with self.assertRaisesRegex(ValueError, "contains punctuation"):
            validate_record(bad_record, Path("sample.jsonl"), 1)

    def test_b_only_input_converts_to_char_i_labels(self) -> None:
        """A valid B-only record converts to characters with ``I-TITLE`` labels."""
        record = {
            "filename": "Title-Name 07",
            "tokens": ["Title", "-", "Name", " ", "07"],
            "labels": ["B-TITLE", "O", "B-TITLE", "O", "B-EPISODE"],
        }

        # Must not raise: the record is already split at punctuation.
        validate_record(record, Path("sample.jsonl"), 1)
        char_record = convert_record(record)

        self.assertIn("I-TITLE", char_record["labels"])
        self.assertEqual(char_record["tokens"][:5], ["T", "i", "t", "l", "e"])

    def test_iter_validated_jsonl_accepts_generated_shape(self) -> None:
        """A one-line JSONL file in the generated shape yields one row."""
        with tempfile.TemporaryDirectory() as tmpdir:
            jsonl_path = Path(tmpdir) / "records.jsonl"
            jsonl_path.write_text(
                '{"filename":"A 01","tokens":["A"," ","01"],"labels":["B-TITLE","O","B-EPISODE"]}\n',
                encoding="utf-8",
            )

            rows = list(iter_validated_jsonl(jsonl_path))

            self.assertEqual(len(rows), 1)
            self.assertEqual(rows[0]["filename"], "A 01")

    # ------------------------------------------------------------------
    # CLI smoke tests (subprocess)
    # ------------------------------------------------------------------

    def test_cli_smoke_annotate_then_convert_with_temp_files(self) -> None:
        """Annotate a one-terminal graph, then convert the result to char data."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workdir = Path(tmpdir)
            graph_path = workdir / "graph.json"
            dataset_path = workdir / "dmhy_weak.generated.jsonl"
            char_path = workdir / "dmhy_weak.generated_char.jsonl"
            vocab_path = workdir / "vocab.generated.char.json"
            manifest_path = workdir / "manifest.json"

            self._write_json(
                graph_path,
                {
                    "terminals": [
                        {
                            "terminal_id": "t0",
                            "weight": 1,
                            "value_examples": [
                                "[ANi] Test Show - 01 [1080P][WEB-DL].mkv"
                            ],
                            "suffix_examples": [" [1080P][WEB-DL]"],
                        }
                    ]
                },
            )

            annotate = self._run_module(
                "tools.annotate_dmhy_prefix_graph",
                "--graph",
                str(graph_path),
                "--output",
                str(dataset_path),
                "--patch-output",
                "",
                "--examples-only",
            )
            self.assertEqual(annotate.returncode, 0, annotate.stderr)

            rows = self._read_jsonl(dataset_path)
            self.assertEqual(len(rows), 1)
            first_row = rows[0]
            self.assertIn("annotations", first_row)
            # The leading "[" of "[ANi]" is expected as its own "O" token.
            self.assertEqual(first_row["tokens"][0], "[")
            self.assertEqual(first_row["labels"][0], "O")

            convert = self._run_module(
                "tools.convert_annotated_dmhy_dataset",
                "--input",
                str(dataset_path),
                "--output",
                str(char_path),
                "--vocab-output",
                str(vocab_path),
                "--manifest-output",
                str(manifest_path),
                "--progress",
                "0",
            )
            self.assertEqual(convert.returncode, 0, convert.stderr)

            for produced in (char_path, vocab_path, manifest_path):
                self.assertTrue(produced.exists())

    def test_cli_source_list_mode_expands_beyond_value_examples(self) -> None:
        """With ``--source-list`` and ``--limit 1`` all three t0 rows are emitted.

        Terminal ``t0`` lists a single value example, yet the source list
        holds three "Full Show" entries; the test expects all three in the
        output and nothing from ``t1``.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            workdir = Path(tmpdir)
            graph_path = workdir / "graph.json"
            source_path = workdir / "dmhy_list.jsonl"
            dataset_path = workdir / "dmhy_weak.generated.jsonl"

            self._write_json(
                graph_path,
                {
                    "terminals": [
                        {
                            "terminal_id": "t0",
                            "prefix": "[ANi] Full Show - ",
                            "weight": 10,
                            "value_examples": [
                                "[ANi] Full Show - 01 [1080P][WEB-DL].mkv"
                            ],
                            "suffix_examples": ["01 [1080P][WEB-DL]"],
                        },
                        {
                            "terminal_id": "t1",
                            "prefix": "[ANi] Other Show - ",
                            "weight": 10,
                            "value_examples": [
                                "[ANi] Other Show - 01 [1080P][WEB-DL].mkv"
                            ],
                            "suffix_examples": ["01 [1080P][WEB-DL]"],
                        },
                    ]
                },
            )
            self._write_source_list(
                source_path,
                [
                    "[ANi] Full Show - 01 [1080P][WEB-DL].mkv",
                    "[ANi] Full Show - 02 [1080P][WEB-DL].mkv",
                    "[ANi] Full Show - 03 [1080P][WEB-DL].mkv",
                    "[ANi] Other Show - 01 [1080P][WEB-DL].mkv",
                ],
            )

            annotate = self._run_module(
                "tools.annotate_dmhy_prefix_graph",
                "--graph",
                str(graph_path),
                "--source-list",
                str(source_path),
                "--output",
                str(dataset_path),
                "--patch-output",
                "",
                "--limit",
                "1",
            )
            self.assertEqual(annotate.returncode, 0, annotate.stderr)

            rows = self._read_jsonl(dataset_path)
            self.assertEqual(len(rows), 3)
            emitted_filenames = [row["filename"] for row in rows]
            self.assertEqual(
                emitted_filenames,
                [
                    "[ANi] Full Show - 01 [1080P][WEB-DL].mkv",
                    "[ANi] Full Show - 02 [1080P][WEB-DL].mkv",
                    "[ANi] Full Show - 03 [1080P][WEB-DL].mkv",
                ],
            )
            self.assertTrue(all(row["terminal_id"] == "t0" for row in rows))

    def test_cli_examples_only_uses_terminal_value_examples(self) -> None:
        """``--examples-only`` emits just the terminal's own value example.

        The source list contains an extra "02" entry; the test expects it
        to be absent from the output.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            workdir = Path(tmpdir)
            graph_path = workdir / "graph.json"
            source_path = workdir / "dmhy_list.jsonl"
            dataset_path = workdir / "dmhy_weak.generated.jsonl"

            self._write_json(
                graph_path,
                {
                    "terminals": [
                        {
                            "terminal_id": "t0",
                            "prefix": "[ANi] Example Show - ",
                            "weight": 10,
                            "value_examples": [
                                "[ANi] Example Show - 01 [1080P][WEB-DL].mkv"
                            ],
                            "suffix_examples": ["01 [1080P][WEB-DL]"],
                        }
                    ]
                },
            )
            self._write_source_list(
                source_path,
                [
                    "[ANi] Example Show - 01 [1080P][WEB-DL].mkv",
                    "[ANi] Example Show - 02 [1080P][WEB-DL].mkv",
                ],
            )

            annotate = self._run_module(
                "tools.annotate_dmhy_prefix_graph",
                "--graph",
                str(graph_path),
                "--source-list",
                str(source_path),
                "--output",
                str(dataset_path),
                "--patch-output",
                "",
                "--examples-only",
            )
            self.assertEqual(annotate.returncode, 0, annotate.stderr)

            rows = self._read_jsonl(dataset_path)
            self.assertEqual(len(rows), 1)
            self.assertEqual(
                rows[0]["filename"], "[ANi] Example Show - 01 [1080P][WEB-DL].mkv"
            )

    def test_cli_dag_annotation_units_include_shared_node_terminals(self) -> None:
        """The DAG CLI reports node 3 as a shared unit covering both terminals.

        Fixture shape: root 0 branches to nodes 1 and 2, both of which lead
        to node 3 via a " shared" edge; node 3 then branches to the two
        terminal nodes 4 and 5.
        """

        def dag_node(node_id, edges, incoming, reachable, weight, *, terminal=False):
            # Key order mirrors the fixture layout so the serialized JSON is stable.
            return {
                "id": node_id,
                "terminal": terminal,
                "children": [
                    {"label": label, "target": target} for label, target in edges
                ],
                "incoming_count": incoming,
                "reachable_terminals": reachable,
                "reachable_weight": weight,
            }

        def dag_terminal(terminal_id, node_id, prefix, skeleton, example):
            return {
                "terminal_id": terminal_id,
                "node_id": node_id,
                "prefix": prefix,
                "digit_skeleton": skeleton,
                "count": 10,
                "weight": 10,
                "suffix_examples": [" [1080P][WEB-DL]"],
                "value_examples": [example],
                "annotations": {},
            }

        dag_payload = {
            "meta": {"version": "prefix-dag-v1"},
            "root": 0,
            "nodes": [
                dag_node(0, [("A", 1), ("B", 2)], 0, 2, 20),
                dag_node(1, [(" shared", 3)], 1, 1, 10),
                dag_node(2, [(" shared", 3)], 1, 1, 10),
                dag_node(3, [(" 01", 4), (" 02", 5)], 2, 2, 20),
                dag_node(4, [], 1, 1, 10, terminal=True),
                dag_node(5, [], 1, 1, 10, terminal=True),
            ],
            "terminals": [
                dag_terminal(
                    "t0",
                    4,
                    "Show A shared 01",
                    "Show A shared <NUM>",
                    "Show A shared 01 [1080P][WEB-DL].mkv",
                ),
                dag_terminal(
                    "t1",
                    5,
                    "Show B shared 02",
                    "Show B shared <NUM>",
                    "Show B shared 02 [1080P][WEB-DL].mkv",
                ),
            ],
        }

        with tempfile.TemporaryDirectory() as tmpdir:
            workdir = Path(tmpdir)
            dag_path = workdir / "dmhy_prefix_dag.json"
            output_path = workdir / "dmhy_prefix_dag.annotation_units.jsonl"
            self._write_json(dag_path, dag_payload)

            annotate = self._run_module(
                "tools.annotate_dmhy_prefix_dag",
                "--dag",
                str(dag_path),
                "--output",
                str(output_path),
                "--min-reachable-terminals",
                "2",
                "--min-incoming-count",
                "2",
                "--limit",
                "1",
            )
            self.assertEqual(annotate.returncode, 0, annotate.stderr)

            rows = self._read_jsonl(output_path)
            self.assertEqual(len(rows), 1)
            unit = rows[0]
            self.assertEqual(unit["unit_id"], "dag-node-3")
            self.assertEqual(unit["kind"], "shared_suffix")
            self.assertEqual(unit["terminal_ids"], ["t0", "t1"])
            self.assertEqual(
                unit["prefix_examples"],
                ["Show A shared 01", "Show B shared 02"],
            )
            self.assertEqual(unit["common_edge_labels"], [" 01", " 02"])
            self.assertIn("annotations", unit)
# Allow running this file directly (``python tools/test_annotated_dmhy_workflow.py``)
# in addition to discovery through a test runner.
if __name__ == "__main__":
    unittest.main()