"""Smoke tests for annotated DMHY graph dataset helpers.""" from __future__ import annotations import tempfile import json import subprocess import sys import unittest from pathlib import Path from tools.annotate_dmhy_prefix_graph import normalize_generated_tokens from tools.convert_annotated_dmhy_dataset import ( iter_validated_jsonl, validate_record, ) from tools.convert_to_char_dataset import convert_record class AnnotatedDmhyWorkflowTests(unittest.TestCase): def test_generated_tokens_split_punctuation_and_use_b_only_labels(self) -> None: tokens, labels = normalize_generated_tokens( ["[ANi]", " ", "Title-Name", "07"], ["B-GROUP", "O", "I-TITLE", "B-EPISODE"], ) self.assertEqual(tokens, ["[", "ANi", "]", " ", "Title", "-", "Name", "07"]) self.assertEqual( labels, ["O", "B-GROUP", "O", "O", "B-TITLE", "O", "B-TITLE", "B-EPISODE"], ) self.assertTrue(all(label == "O" or label.startswith("B-") for label in labels)) def test_preserve_i_labels_keeps_i_on_non_separator_pieces(self) -> None: tokens, labels = normalize_generated_tokens( ["Title-Name"], ["I-TITLE"], preserve_i_labels=True, ) self.assertEqual(tokens, ["Title", "-", "Name"]) self.assertEqual(labels, ["I-TITLE", "O", "I-TITLE"]) def test_validation_rejects_embedded_punctuation(self) -> None: record = { "filename": "Title-Name 07", "tokens": ["Title-Name", "07"], "labels": ["B-TITLE", "B-EPISODE"], } with self.assertRaisesRegex(ValueError, "contains punctuation"): validate_record(record, Path("sample.jsonl"), 1) def test_validation_rejects_embedded_symbol_separator(self) -> None: record = { "filename": "Title 1920×1080 07", "tokens": ["Title", "1920×1080", "07"], "labels": ["B-TITLE", "B-RESOLUTION", "B-EPISODE"], } with self.assertRaisesRegex(ValueError, "contains punctuation"): validate_record(record, Path("sample.jsonl"), 1) def test_b_only_input_converts_to_char_i_labels(self) -> None: record = { "filename": "Title-Name 07", "tokens": ["Title", "-", "Name", " ", "07"], "labels": ["B-TITLE", "O", "B-TITLE", "O", "B-EPISODE"], } validate_record(record, Path("sample.jsonl"), 1) converted = convert_record(record) self.assertIn("I-TITLE", converted["labels"]) self.assertEqual(converted["tokens"][:5], ["T", "i", "t", "l", "e"]) def test_iter_validated_jsonl_accepts_generated_shape(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "records.jsonl" path.write_text( '{"filename":"A 01","tokens":["A"," ","01"],"labels":["B-TITLE","O","B-EPISODE"]}\n', encoding="utf-8", ) rows = list(iter_validated_jsonl(path)) self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["filename"], "A 01") def test_cli_smoke_annotate_then_convert_with_temp_files(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: tmp = Path(tmpdir) graph_path = tmp / "graph.json" dataset_path = tmp / "dmhy_weak.generated.jsonl" char_path = tmp / "dmhy_weak.generated_char.jsonl" vocab_path = tmp / "vocab.generated.char.json" manifest_path = tmp / "manifest.json" graph_path.write_text( json.dumps( { "terminals": [ { "terminal_id": "t0", "weight": 1, "value_examples": [ "[ANi] Test Show - 01 [1080P][WEB-DL].mkv" ], "suffix_examples": [" [1080P][WEB-DL]"], } ] }, ensure_ascii=False, ), encoding="utf-8", ) annotate = subprocess.run( [ sys.executable, "-m", "tools.annotate_dmhy_prefix_graph", "--graph", str(graph_path), "--output", str(dataset_path), "--patch-output", "", "--examples-only", ], check=False, capture_output=True, text=True, ) self.assertEqual(annotate.returncode, 0, annotate.stderr) rows = [ json.loads(line) for line in dataset_path.read_text(encoding="utf-8").splitlines() if line.strip() ] self.assertEqual(len(rows), 1) self.assertIn("annotations", rows[0]) self.assertEqual(rows[0]["tokens"][0], "[") self.assertEqual(rows[0]["labels"][0], "O") convert = subprocess.run( [ sys.executable, "-m", "tools.convert_annotated_dmhy_dataset", "--input", str(dataset_path), "--output", str(char_path), "--vocab-output", str(vocab_path), "--manifest-output", str(manifest_path), "--progress", "0", ], check=False, capture_output=True, text=True, ) self.assertEqual(convert.returncode, 0, convert.stderr) self.assertTrue(char_path.exists()) self.assertTrue(vocab_path.exists()) self.assertTrue(manifest_path.exists()) def test_cli_source_list_mode_expands_beyond_value_examples(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: tmp = Path(tmpdir) graph_path = tmp / "graph.json" source_path = tmp / "dmhy_list.jsonl" dataset_path = tmp / "dmhy_weak.generated.jsonl" graph_path.write_text( json.dumps( { "terminals": [ { "terminal_id": "t0", "prefix": "[ANi] Full Show - ", "weight": 10, "value_examples": [ "[ANi] Full Show - 01 [1080P][WEB-DL].mkv" ], "suffix_examples": ["01 [1080P][WEB-DL]"], }, { "terminal_id": "t1", "prefix": "[ANi] Other Show - ", "weight": 10, "value_examples": [ "[ANi] Other Show - 01 [1080P][WEB-DL].mkv" ], "suffix_examples": ["01 [1080P][WEB-DL]"], }, ] }, ensure_ascii=False, ), encoding="utf-8", ) source_path.write_text( "\n".join( json.dumps({"value": value}, ensure_ascii=False) for value in [ "[ANi] Full Show - 01 [1080P][WEB-DL].mkv", "[ANi] Full Show - 02 [1080P][WEB-DL].mkv", "[ANi] Full Show - 03 [1080P][WEB-DL].mkv", "[ANi] Other Show - 01 [1080P][WEB-DL].mkv", ] ) + "\n", encoding="utf-8", ) annotate = subprocess.run( [ sys.executable, "-m", "tools.annotate_dmhy_prefix_graph", "--graph", str(graph_path), "--source-list", str(source_path), "--output", str(dataset_path), "--patch-output", "", "--limit", "1", ], check=False, capture_output=True, text=True, ) self.assertEqual(annotate.returncode, 0, annotate.stderr) rows = [ json.loads(line) for line in dataset_path.read_text(encoding="utf-8").splitlines() if line.strip() ] self.assertEqual(len(rows), 3) self.assertEqual([row["filename"] for row in rows], [ "[ANi] Full Show - 01 [1080P][WEB-DL].mkv", "[ANi] Full Show - 02 [1080P][WEB-DL].mkv", "[ANi] Full Show - 03 [1080P][WEB-DL].mkv", ]) self.assertTrue(all(row["terminal_id"] == "t0" for row in rows)) def test_cli_examples_only_uses_terminal_value_examples(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: tmp = Path(tmpdir) graph_path = tmp / "graph.json" source_path = tmp / "dmhy_list.jsonl" dataset_path = tmp / "dmhy_weak.generated.jsonl" graph_path.write_text( json.dumps( { "terminals": [ { "terminal_id": "t0", "prefix": "[ANi] Example Show - ", "weight": 10, "value_examples": [ "[ANi] Example Show - 01 [1080P][WEB-DL].mkv" ], "suffix_examples": ["01 [1080P][WEB-DL]"], } ] }, ensure_ascii=False, ), encoding="utf-8", ) source_path.write_text( "\n".join( json.dumps({"value": value}, ensure_ascii=False) for value in [ "[ANi] Example Show - 01 [1080P][WEB-DL].mkv", "[ANi] Example Show - 02 [1080P][WEB-DL].mkv", ] ) + "\n", encoding="utf-8", ) annotate = subprocess.run( [ sys.executable, "-m", "tools.annotate_dmhy_prefix_graph", "--graph", str(graph_path), "--source-list", str(source_path), "--output", str(dataset_path), "--patch-output", "", "--examples-only", ], check=False, capture_output=True, text=True, ) self.assertEqual(annotate.returncode, 0, annotate.stderr) rows = [ json.loads(line) for line in dataset_path.read_text(encoding="utf-8").splitlines() if line.strip() ] self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["filename"], "[ANi] Example Show - 01 [1080P][WEB-DL].mkv") def test_cli_dag_annotation_units_include_shared_node_terminals(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: tmp = Path(tmpdir) dag_path = tmp / "dmhy_prefix_dag.json" output_path = tmp / "dmhy_prefix_dag.annotation_units.jsonl" dag_path.write_text( json.dumps( { "meta": {"version": "prefix-dag-v1"}, "root": 0, "nodes": [ { "id": 0, "terminal": False, "children": [ {"label": "A", "target": 1}, {"label": "B", "target": 2}, ], "incoming_count": 0, "reachable_terminals": 2, "reachable_weight": 20, }, { "id": 1, "terminal": False, "children": [{"label": " shared", "target": 3}], "incoming_count": 1, "reachable_terminals": 1, "reachable_weight": 10, }, { "id": 2, "terminal": False, "children": [{"label": " shared", "target": 3}], "incoming_count": 1, "reachable_terminals": 1, "reachable_weight": 10, }, { "id": 3, "terminal": False, "children": [ {"label": " 01", "target": 4}, {"label": " 02", "target": 5}, ], "incoming_count": 2, "reachable_terminals": 2, "reachable_weight": 20, }, { "id": 4, "terminal": True, "children": [], "incoming_count": 1, "reachable_terminals": 1, "reachable_weight": 10, }, { "id": 5, "terminal": True, "children": [], "incoming_count": 1, "reachable_terminals": 1, "reachable_weight": 10, }, ], "terminals": [ { "terminal_id": "t0", "node_id": 4, "prefix": "Show A shared 01", "digit_skeleton": "Show A shared ", "count": 10, "weight": 10, "suffix_examples": [" [1080P][WEB-DL]"], "value_examples": ["Show A shared 01 [1080P][WEB-DL].mkv"], "annotations": {}, }, { "terminal_id": "t1", "node_id": 5, "prefix": "Show B shared 02", "digit_skeleton": "Show B shared ", "count": 10, "weight": 10, "suffix_examples": [" [1080P][WEB-DL]"], "value_examples": ["Show B shared 02 [1080P][WEB-DL].mkv"], "annotations": {}, }, ], }, ensure_ascii=False, ), encoding="utf-8", ) annotate = subprocess.run( [ sys.executable, "-m", "tools.annotate_dmhy_prefix_dag", "--dag", str(dag_path), "--output", str(output_path), "--min-reachable-terminals", "2", "--min-incoming-count", "2", "--limit", "1", ], check=False, capture_output=True, text=True, ) self.assertEqual(annotate.returncode, 0, annotate.stderr) rows = [ json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines() if line.strip() ] self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["unit_id"], "dag-node-3") self.assertEqual(rows[0]["kind"], "shared_suffix") self.assertEqual(rows[0]["terminal_ids"], ["t0", "t1"]) self.assertEqual( rows[0]["prefix_examples"], ["Show A shared 01", "Show B shared 02"], ) self.assertEqual(rows[0]["common_edge_labels"], [" 01", " 02"]) self.assertIn("annotations", rows[0]) if __name__ == "__main__": unittest.main()