"""Tests for dataset building and label logic.""" import pytest from src.datasets.combined_pairs_dataset import ( CombinedPairsConfig, CombinedPairsDataset, ID2LABEL, LABEL2ID, NUM_LABELS, ) from src.datasets.create_recipes_dataset import format_recipe_as_document from src.schemas.sentence_labels import SentenceLabels class TestSentenceLabels: def test_label_values(self): labels = SentenceLabels() assert labels.SAME_PARAGRAPH == 0 assert labels.NEW_PARAGRAPH == 1 assert labels.NEWLINE == 2 def test_num_labels(self): assert NUM_LABELS == 3 def test_id2label_mapping(self): assert ID2LABEL[0] == "SAME_PARAGRAPH" assert ID2LABEL[1] == "NEW_PARAGRAPH" assert ID2LABEL[2] == "NEWLINE" def test_label2id_is_inverse(self): for k, v in ID2LABEL.items(): assert LABEL2ID[v] == k class TestCombinedPairsConfig: def test_default_config(self): cfg = CombinedPairsConfig() assert cfg.seed == 42 assert cfg.max_length == 512 assert cfg.gutenberg_train_cap == 45_000 assert cfg.recipes_train_cap is None assert cfg.exclude_domains == set() def test_exclude_domains(self): cfg = CombinedPairsConfig(exclude_domains={"gutenberg", "recipes"}) assert "gutenberg" in cfg.exclude_domains assert "recipes" in cfg.exclude_domains class TestCombinedPairsDataset: def test_build_splits_returns_three_splits(self): cfg = CombinedPairsConfig() builder = CombinedPairsDataset(cfg) splits = builder.build_splits() assert "train" in splits assert "val" in splits assert "test" in splits def test_build_splits_pairs_have_required_fields(self): cfg = CombinedPairsConfig() builder = CombinedPairsDataset(cfg) splits = builder.build_splits() for pair in splits["train"][:10]: assert "sentence1" in pair assert "sentence2" in pair assert "label" in pair assert pair["label"] in (0, 1, 2) def test_exclude_domain_reduces_pairs(self): cfg_all = CombinedPairsConfig() cfg_excl = CombinedPairsConfig(exclude_domains={"gutenberg"}) splits_all = CombinedPairsDataset(cfg_all).build_splits() splits_excl = CombinedPairsDataset(cfg_excl).build_splits() assert len(splits_excl["train"]) < len(splits_all["train"]) def test_class_weights_shape(self): cfg = CombinedPairsConfig() builder = CombinedPairsDataset(cfg) splits = builder.build_splits() weights = builder.compute_class_weights(splits["train"]) assert weights.shape == (NUM_LABELS,) assert all(w > 0 for w in weights) class TestFormatRecipe: def test_basic_format(self): rec = { "title": "Test Recipe", "ingredients": '["1 cup flour", "2 eggs"]', "directions": '["Mix ingredients.", "Bake at 350."]', } doc = format_recipe_as_document(rec) assert "Test Recipe" in doc assert "Ingredients:" in doc assert "Directions:" in doc assert "1 cup flour" in doc assert "2 eggs" in doc assert "Mix ingredients." in doc assert "Bake at 350." in doc def test_has_paragraph_breaks(self): rec = { "title": "Test Recipe", "ingredients": '["flour"]', "directions": '["Mix."]', } doc = format_recipe_as_document(rec) assert "\n\n" in doc def test_ingredients_on_separate_lines(self): rec = { "title": "Test", "ingredients": '["flour", "sugar", "butter"]', "directions": '["Mix."]', } doc = format_recipe_as_document(rec) # ingredients section should have newlines between items ing_section = doc.split("Ingredients:\n")[1].split("\n\n")[0] lines = [l for l in ing_section.split("\n") if l.strip()] assert len(lines) == 3 def test_directions_are_numbered(self): rec = { "title": "Test", "ingredients": '["flour"]', "directions": '["Step one.", "Step two."]', } doc = format_recipe_as_document(rec) dir_section = doc.split("Directions:\n")[1] # should start with 1. or 1) first_line = dir_section.split("\n")[0] assert first_line.startswith("1.") or first_line.startswith("1)")