| """Tests for dataset building and label logic.""" |
| import pytest |
|
|
| from src.datasets.combined_pairs_dataset import ( |
| CombinedPairsConfig, |
| CombinedPairsDataset, |
| ID2LABEL, |
| LABEL2ID, |
| NUM_LABELS, |
| ) |
| from src.datasets.create_recipes_dataset import format_recipe_as_document |
| from src.schemas.sentence_labels import SentenceLabels |
|
|
|
|
class TestSentenceLabels:
    """Check the label constants and the id <-> name lookup tables."""

    def test_label_values(self):
        labels = SentenceLabels()
        assert labels.SAME_PARAGRAPH == 0
        assert labels.NEW_PARAGRAPH == 1
        assert labels.NEWLINE == 2

    def test_num_labels(self):
        assert NUM_LABELS == 3
        # The id -> name table must cover exactly NUM_LABELS ids, no more.
        assert len(ID2LABEL) == NUM_LABELS

    def test_id2label_mapping(self):
        assert ID2LABEL[0] == "SAME_PARAGRAPH"
        assert ID2LABEL[1] == "NEW_PARAGRAPH"
        assert ID2LABEL[2] == "NEWLINE"

    def test_id2label_agrees_with_sentence_labels(self):
        # Each name in ID2LABEL is also an attribute of SentenceLabels; both
        # sources must assign it the same integer id.
        labels = SentenceLabels()
        for idx, name in ID2LABEL.items():
            assert getattr(labels, name) == idx

    def test_label2id_is_inverse(self):
        # Equal sizes rule out extra names in LABEL2ID that the loop below,
        # which only walks ID2LABEL, would never visit.
        assert len(LABEL2ID) == len(ID2LABEL)
        for k, v in ID2LABEL.items():
            assert LABEL2ID[v] == k
|
|
|
|
class TestCombinedPairsConfig:
    """Tests for the default values and overrides of CombinedPairsConfig."""

    def test_default_config(self):
        config = CombinedPairsConfig()
        expected_values = {
            "seed": 42,
            "max_length": 512,
            "gutenberg_train_cap": 45_000,
        }
        for attr_name, expected in expected_values.items():
            assert getattr(config, attr_name) == expected
        assert config.recipes_train_cap is None
        assert config.exclude_domains == set()

    def test_exclude_domains(self):
        wanted = {"gutenberg", "recipes"}
        config = CombinedPairsConfig(exclude_domains=wanted)
        for domain in ("gutenberg", "recipes"):
            assert domain in config.exclude_domains
|
|
|
|
class TestCombinedPairsDataset:
    """Tests for CombinedPairsDataset.build_splits and compute_class_weights."""

    @pytest.fixture(scope="class")
    def default_build(self):
        """Build the default-config dataset once and share it across this class.

        Without this fixture every test below called build_splits() on a
        fresh default-config builder. Sharing one build means tests must
        treat the returned builder and splits as read-only.

        Returns:
            A ``(builder, splits)`` tuple, where ``splits`` is the result of
            ``builder.build_splits()``.
        """
        builder = CombinedPairsDataset(CombinedPairsConfig())
        return builder, builder.build_splits()

    def test_build_splits_returns_three_splits(self, default_build):
        _, splits = default_build
        for split_name in ("train", "val", "test"):
            assert split_name in splits

    def test_build_splits_pairs_have_required_fields(self, default_build):
        _, splits = default_build
        train = splits["train"]
        # Without this guard an empty train split would pass vacuously,
        # because the loop below would check nothing.
        assert len(train) > 0
        for pair in train[:10]:
            assert "sentence1" in pair
            assert "sentence2" in pair
            assert "label" in pair
            assert pair["label"] in (0, 1, 2)

    def test_exclude_domain_reduces_pairs(self, default_build):
        _, splits_all = default_build
        cfg_excl = CombinedPairsConfig(exclude_domains={"gutenberg"})
        splits_excl = CombinedPairsDataset(cfg_excl).build_splits()
        assert len(splits_excl["train"]) < len(splits_all["train"])

    def test_class_weights_shape(self, default_build):
        builder, splits = default_build
        weights = builder.compute_class_weights(splits["train"])
        assert weights.shape == (NUM_LABELS,)
        assert all(w > 0 for w in weights)
|
|
|
|
class TestFormatRecipe:
    """Tests for format_recipe_as_document on small hand-written recipe records.

    The ``ingredients`` and ``directions`` fields are given as JSON-style
    list strings, matching the records used throughout these tests.
    """

    def test_basic_format(self):
        rec = {
            "title": "Test Recipe",
            "ingredients": '["1 cup flour", "2 eggs"]',
            "directions": '["Mix ingredients.", "Bake at 350."]',
        }
        doc = format_recipe_as_document(rec)
        assert "Test Recipe" in doc
        assert "Ingredients:" in doc
        assert "Directions:" in doc
        assert "1 cup flour" in doc
        assert "2 eggs" in doc
        assert "Mix ingredients." in doc
        assert "Bake at 350." in doc

    def test_has_paragraph_breaks(self):
        rec = {
            "title": "Test Recipe",
            "ingredients": '["flour"]',
            "directions": '["Mix."]',
        }
        doc = format_recipe_as_document(rec)
        assert "\n\n" in doc

    def test_ingredients_on_separate_lines(self):
        rec = {
            "title": "Test",
            "ingredients": '["flour", "sugar", "butter"]',
            "directions": '["Mix."]',
        }
        doc = format_recipe_as_document(rec)
        _, header, after_header = doc.partition("Ingredients:\n")
        # Fail with a clear message rather than an IndexError if the header is absent.
        assert header, "document has no 'Ingredients:' header line"
        # The ingredients section runs until the next blank-line paragraph break.
        ing_section = after_header.split("\n\n")[0]
        ing_lines = [line for line in ing_section.split("\n") if line.strip()]
        assert len(ing_lines) == 3

    def test_directions_are_numbered(self):
        rec = {
            "title": "Test",
            "ingredients": '["flour"]',
            "directions": '["Step one.", "Step two."]',
        }
        doc = format_recipe_as_document(rec)
        _, header, dir_section = doc.partition("Directions:\n")
        assert header, "document has no 'Directions:' header line"
        # Either "1." or "1)" numbering style is accepted for the first step.
        first_line = dir_section.split("\n")[0]
        assert first_line.startswith(("1.", "1)"))
|
|