# Source: bc-test/tests/test_dataset.py
# Commit 594db4d ("tests") by lamossta
"""Tests for dataset building and label logic."""
import pytest
from src.datasets.combined_pairs_dataset import (
CombinedPairsConfig,
CombinedPairsDataset,
ID2LABEL,
LABEL2ID,
NUM_LABELS,
)
from src.datasets.create_recipes_dataset import format_recipe_as_document
from src.schemas.sentence_labels import SentenceLabels
class TestSentenceLabels:
    """Label ids of SentenceLabels and the ID2LABEL / LABEL2ID mappings."""

    def test_label_values(self):
        labels = SentenceLabels()
        assert labels.SAME_PARAGRAPH == 0
        assert labels.NEW_PARAGRAPH == 1
        assert labels.NEWLINE == 2

    def test_num_labels(self):
        assert NUM_LABELS == 3
        # Every label id counted by NUM_LABELS must have a name in ID2LABEL.
        assert len(ID2LABEL) == NUM_LABELS

    def test_id2label_mapping(self):
        assert ID2LABEL[0] == "SAME_PARAGRAPH"
        assert ID2LABEL[1] == "NEW_PARAGRAPH"
        assert ID2LABEL[2] == "NEWLINE"

    def test_label2id_is_inverse(self):
        # Equal sizes rule out extra entries in LABEL2ID, which the
        # per-item loop below would not detect on its own.
        assert len(LABEL2ID) == len(ID2LABEL)
        for label_id, label_name in ID2LABEL.items():
            assert LABEL2ID[label_name] == label_id
class TestCombinedPairsConfig:
    """Default values and overrides of CombinedPairsConfig."""

    def test_default_config(self):
        config = CombinedPairsConfig()
        expected_defaults = {
            "seed": 42,
            "max_length": 512,
            "gutenberg_train_cap": 45_000,
            "exclude_domains": set(),
        }
        for attr_name, expected in expected_defaults.items():
            assert getattr(config, attr_name) == expected, attr_name
        # No cap on recipes by default.
        assert config.recipes_train_cap is None

    def test_exclude_domains(self):
        excluded = {"gutenberg", "recipes"}
        config = CombinedPairsConfig(exclude_domains=excluded)
        for domain in ("gutenberg", "recipes"):
            assert domain in config.exclude_domains
class TestCombinedPairsDataset:
    """End-to-end checks of CombinedPairsDataset split building and class weights."""

    @pytest.fixture(scope="class")
    def default_build(self):
        """Build the default-config dataset once and share it across this class.

        Returns a ``(builder, splits)`` tuple. The tests below only read from
        it, so the potentially expensive ``build_splits`` call does not need
        to be repeated for every test.
        """
        builder = CombinedPairsDataset(CombinedPairsConfig())
        return builder, builder.build_splits()

    def test_build_splits_returns_three_splits(self, default_build):
        _, splits = default_build
        for split_name in ("train", "val", "test"):
            assert split_name in splits

    def test_build_splits_pairs_have_required_fields(self, default_build):
        _, splits = default_build
        train = splits["train"]
        # Without this, an empty train split would make the loop below a no-op
        # and the test would pass vacuously.
        assert len(train) > 0
        for pair in train[:10]:
            assert "sentence1" in pair
            assert "sentence2" in pair
            assert "label" in pair
            assert pair["label"] in range(NUM_LABELS)

    def test_exclude_domain_reduces_pairs(self, default_build):
        _, splits_all = default_build
        cfg_excl = CombinedPairsConfig(exclude_domains={"gutenberg"})
        splits_excl = CombinedPairsDataset(cfg_excl).build_splits()
        assert len(splits_excl["train"]) < len(splits_all["train"])

    def test_class_weights_shape(self, default_build):
        builder, splits = default_build
        weights = builder.compute_class_weights(splits["train"])
        assert weights.shape == (NUM_LABELS,)
        assert all(w > 0 for w in weights)
class TestFormatRecipe:
    """Rendering of a raw recipe record by format_recipe_as_document."""

    def test_basic_format(self):
        rec = {
            "title": "Test Recipe",
            "ingredients": '["1 cup flour", "2 eggs"]',
            "directions": '["Mix ingredients.", "Bake at 350."]',
        }
        doc = format_recipe_as_document(rec)
        assert "Test Recipe" in doc
        assert "Ingredients:" in doc
        assert "Directions:" in doc
        assert "1 cup flour" in doc
        assert "2 eggs" in doc
        assert "Mix ingredients." in doc
        assert "Bake at 350." in doc

    def test_has_paragraph_breaks(self):
        rec = {
            "title": "Test Recipe",
            "ingredients": '["flour"]',
            "directions": '["Mix."]',
        }
        doc = format_recipe_as_document(rec)
        assert "\n\n" in doc

    def test_ingredients_on_separate_lines(self):
        rec = {
            "title": "Test",
            "ingredients": '["flour", "sugar", "butter"]',
            "directions": '["Mix."]',
        }
        doc = format_recipe_as_document(rec)
        # Check the header explicitly so a missing one fails clearly
        # instead of raising IndexError on the split below.
        assert "Ingredients:\n" in doc
        # The ingredients section runs from its header to the next blank line;
        # each item should sit on its own line.
        ing_section = doc.split("Ingredients:\n")[1].split("\n\n")[0]
        ing_lines = [line for line in ing_section.split("\n") if line.strip()]
        assert len(ing_lines) == 3

    def test_directions_are_numbered(self):
        rec = {
            "title": "Test",
            "ingredients": '["flour"]',
            "directions": '["Step one.", "Step two."]',
        }
        doc = format_recipe_as_document(rec)
        assert "Directions:\n" in doc
        dir_lines = doc.split("Directions:\n")[1].split("\n")
        # The first step starts right after the header, numbered "1." or "1)".
        assert dir_lines[0].startswith(("1.", "1)"))
        # The second step must be numbered too; checking only step 1 would
        # accept a list where later steps are unnumbered.
        assert any(line.startswith(("2.", "2)")) for line in dir_lines[1:])