image-captioning-api / tests /unit /test_splits.py
apoorvrajdev's picture
feat: finalize Phase 1 modular ML architecture
3a2e5f0
"""Tests for ``captioning.data.splits.make_image_level_splits``."""
from __future__ import annotations
import pandas as pd
from captioning.data.splits import make_image_level_splits
def _build_corpus(n_images: int = 10, captions_per_image: int = 5) -> pd.DataFrame:
rows = []
for i in range(n_images):
for j in range(captions_per_image):
rows.append({"image": f"/img/{i}.jpg", "caption": f"caption {i}-{j}"})
return pd.DataFrame(rows)
def test_splits_are_image_level() -> None:
    """The same image must NOT appear in both train and val — that's the
    whole point of doing image-level (rather than caption-level) splitting."""
    df = _build_corpus(n_images=10, captions_per_image=5)
    train_imgs, _, val_imgs, _ = make_image_level_splits(df, train_fraction=0.8, seed=0)
    train_set, val_set = set(train_imgs), set(val_imgs)
    # An empty side is trivially disjoint from anything, so require both sides
    # to be populated before trusting the disjointness check below.
    assert train_set, "train split is empty"
    assert val_set, "val split is empty"
    assert train_set.isdisjoint(val_set)
def test_splits_preserve_total_count() -> None:
    """Splitting must preserve the total number of caption rows.

    Within each split the image and caption sequences have equal length, and
    the two caption sequences together are as long as the input frame.
    """
    corpus = _build_corpus(n_images=10, captions_per_image=5)
    split = make_image_level_splits(corpus, train_fraction=0.8, seed=0)
    train_imgs, train_caps, val_imgs, val_caps = split

    # Images and captions are expected to be parallel within each split.
    assert len(train_imgs) == len(train_caps)
    assert len(val_imgs) == len(val_caps)
    # Train and val together must account for the full row count.
    assert len(train_caps) + len(val_caps) == len(corpus)
def test_splits_are_seed_reproducible() -> None:
    """Two calls with the same corpus, fraction and seed must compare equal."""
    corpus = _build_corpus(n_images=20, captions_per_image=3)
    split_kwargs = {"train_fraction": 0.8, "seed": 123}
    first_run = make_image_level_splits(corpus, **split_kwargs)
    second_run = make_image_level_splits(corpus, **split_kwargs)
    assert first_run == second_run
def test_splits_seed_changes_partition() -> None:
    """Different seeds should (almost always) produce different splits."""
    corpus = _build_corpus(n_images=20, captions_per_image=3)
    train_by_seed = {}
    for seed in (1, 2):
        # Only the train image sequence is compared; the rest is discarded.
        train_imgs, _, _, _ = make_image_level_splits(corpus, train_fraction=0.8, seed=seed)
        train_by_seed[seed] = train_imgs
    assert train_by_seed[1] != train_by_seed[2]
def test_train_fraction_uses_int_truncation_like_notebook() -> None:
    """The train image count must be ``int(n_images * train_fraction)``.

    The original note on this test records that notebook cell 11 uses
    ``int(len(img_keys) * 0.8)``; the splitter should preserve that int()
    behaviour. With 10 images:

    * fraction 0.85 -> ``int(8.5) == 8`` train / 2 val. Python's built-in
      ``round`` rounds ties to even, so ``round(8.5)`` is also 8 and this case
      alone cannot tell truncation from rounding.
    * fraction 0.89 -> ``int(8.9) == 8`` train / 2 val, whereas ``round(8.9)``
      would give 9/1. This is the case that actually pins truncation.
    """
    df = _build_corpus(n_images=10, captions_per_image=2)
    for fraction in (0.85, 0.89):
        train_imgs, _, val_imgs, _ = make_image_level_splits(
            df, train_fraction=fraction, seed=0
        )
        # Count distinct images: each image contributes several caption rows.
        train_unique = len(set(train_imgs))
        val_unique = len(set(val_imgs))
        assert train_unique == 8, f"train_fraction={fraction}"
        assert val_unique == 2, f"train_fraction={fraction}"