File size: 2,422 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Tests for ``captioning.data.splits.make_image_level_splits``."""

from __future__ import annotations

import pandas as pd

from captioning.data.splits import make_image_level_splits


def _build_corpus(n_images: int = 10, captions_per_image: int = 5) -> pd.DataFrame:
    """Build a synthetic caption corpus with one row per (image, caption) pair.

    Image ``i`` gets the path ``/img/{i}.jpg`` and captions
    ``caption {i}-0`` .. ``caption {i}-{captions_per_image - 1}``; rows are
    ordered by image first, then by caption index.
    """
    records = [
        {"image": f"/img/{img_idx}.jpg", "caption": f"caption {img_idx}-{cap_idx}"}
        for img_idx in range(n_images)
        for cap_idx in range(captions_per_image)
    ]
    return pd.DataFrame(records)


def test_splits_are_image_level() -> None:
    """No image may land in both train and val.

    Keeping all captions of an image on one side is the whole reason for
    splitting at image level instead of caption level.
    """
    corpus = _build_corpus(n_images=10, captions_per_image=5)
    train_imgs, _, val_imgs, _ = make_image_level_splits(corpus, train_fraction=0.8, seed=0)
    shared_images = set(train_imgs) & set(val_imgs)
    assert not shared_images


def test_splits_preserve_total_count() -> None:
    """Images and captions stay aligned per split, and no caption row is lost."""
    corpus = _build_corpus(n_images=10, captions_per_image=5)
    split = make_image_level_splits(corpus, train_fraction=0.8, seed=0)
    train_imgs, train_caps, val_imgs, val_caps = split

    # Within each split there is exactly one image entry per caption entry.
    assert len(train_imgs) == len(train_caps)
    assert len(val_imgs) == len(val_caps)

    # Together the two splits account for every row of the corpus.
    n_train = len(train_caps)
    n_val = len(val_caps)
    assert n_train + n_val == len(corpus)


def test_splits_are_seed_reproducible() -> None:
    """Two calls with the same seed must return equal results."""
    corpus = _build_corpus(n_images=20, captions_per_image=3)
    first_run, second_run = (
        make_image_level_splits(corpus, train_fraction=0.8, seed=123) for _ in range(2)
    )
    assert first_run == second_run


def test_splits_seed_changes_partition() -> None:
    """Seeds 1 and 2 are expected to yield different train partitions.

    Distinct seeds should (almost always) shuffle the images differently.
    """
    corpus = _build_corpus(n_images=20, captions_per_image=3)
    train_by_seed = {}
    for seed in (1, 2):
        train_imgs, _, _, _ = make_image_level_splits(corpus, train_fraction=0.8, seed=seed)
        train_by_seed[seed] = train_imgs
    assert train_by_seed[1] != train_by_seed[2]


def test_train_fraction_uses_int_truncation_like_notebook() -> None:
    """The train image count must be truncated with ``int()``, not rounded.

    Notebook cell 11 uses ``int(len(img_keys) * 0.8)``; preserve the
    notebook's int() behaviour. Two fractions are checked against 10 images:

    * 0.85 -> 8.5: truncation gives 8 train / 2 val, while round-half-up
      would give 9/1. Python's built-in ``round`` rounds halves to even, so
      ``round(8.5)`` is also 8 and this case alone cannot tell ``int`` from
      ``round``.
    * 0.87 -> 8.7: truncation still gives 8/2, whereas ``round`` (or any
      ceiling) would give 9/1, so this case pins the truncation.
    """
    df = _build_corpus(n_images=10, captions_per_image=2)
    for fraction in (0.85, 0.87):
        train_imgs, _, val_imgs, _ = make_image_level_splits(
            df, train_fraction=fraction, seed=0
        )
        train_unique = len(set(train_imgs))
        val_unique = len(set(val_imgs))
        # The fraction is attached as the assertion message to identify the failing case.
        assert train_unique == 8, fraction
        assert val_unique == 2, fraction