File size: 2,982 Bytes
1d1e600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Image-level train/val splitting.

Mirrors notebook cell 11. The split is intentionally at the *image* level
(not the caption level): every image owns ~5 captions in COCO, and putting
some of an image's captions in train and others in val would be data leakage.

The notebook does this correctly via the ``img_to_cap_vector`` defaultdict
loop; we preserve that exact algorithm but inject the seed so the split is
reproducible across runs.
"""

from __future__ import annotations

import collections
import random

import pandas as pd

from captioning.utils.logging import get_logger

# Logger for this module, obtained from the project's get_logger helper;
# make_image_level_splits uses it to report the sizes of the splits it builds.
log = get_logger(__name__)


def _expand_image_caption_pairs(
    keys: list[str],
    img_to_cap: dict[str, list[str]],
) -> tuple[list[str], list[str]]:
    """Flatten ``keys`` into parallel image/caption lists, one entry per caption."""
    imgs: list[str] = []
    caps: list[str] = []
    for key in keys:
        image_caps = img_to_cap[key]
        # Repeat the image key once per caption so both lists stay aligned.
        imgs.extend([key] * len(image_caps))
        caps.extend(image_caps)
    return imgs, caps


def make_image_level_splits(
    captions: pd.DataFrame,
    train_fraction: float = 0.8,
    seed: int | None = None,
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Split captions into train/val while keeping all of an image's
    captions in the same split.

    Mirrors notebook cell 11 exactly when ``seed`` is the same value that was
    fed to ``random.seed`` before the notebook ran. ``seed=None`` reproduces
    the notebook's non-deterministic behaviour.

    Args:
        captions: DataFrame with ``image`` and ``caption`` columns
            (preprocessed if you want preprocessed splits — the loader applies
            ``preprocess_caption`` upstream).
        train_fraction: Fraction of *unique images* assigned to the train
            split, in the closed interval ``[0, 1]``. The notebook uses
            ``int(len(img_keys) * 0.8)``, which we preserve byte-for-byte
            (``int()`` truncates, not rounds).
        seed: If provided, used to seed a private ``random.Random`` for the
            shuffle; otherwise the module-level ``random.shuffle`` is used.

    Returns:
        Tuple ``(train_imgs, train_captions, val_imgs, val_captions)`` where
        each list has one entry per (image, caption) pair, expanded so an
        image with N captions appears N times.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]`` (NaN
            included), or if the ``image`` and ``caption`` columns differ in
            length (raised by ``zip(..., strict=True)``).
    """
    # Reject out-of-range fractions up front: a negative value would become a
    # negative slice index (silently splitting from the wrong end) and a value
    # above 1 would silently leave the validation split empty.
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be within [0.0, 1.0], got {train_fraction!r}"
        )

    # Group captions per image; dict insertion order fixes the pre-shuffle
    # key order, which is what makes a seeded shuffle reproducible.
    img_to_cap: dict[str, list[str]] = collections.defaultdict(list)
    for img, cap in zip(captions["image"], captions["caption"], strict=True):
        img_to_cap[img].append(cap)

    img_keys = list(img_to_cap.keys())

    if seed is not None:
        # A private Random instance gives a reproducible shuffle without
        # touching the global RNG state.
        rng = random.Random(seed)
        rng.shuffle(img_keys)
    else:
        random.shuffle(img_keys)

    # int() truncates, so the train split never gets more than its share.
    slice_index = int(len(img_keys) * train_fraction)
    train_keys, val_keys = img_keys[:slice_index], img_keys[slice_index:]

    train_imgs, train_captions = _expand_image_caption_pairs(train_keys, img_to_cap)
    val_imgs, val_captions = _expand_image_caption_pairs(val_keys, img_to_cap)

    log.info(
        "splits_made",
        train_images=len(train_keys),
        val_images=len(val_keys),
        train_captions=len(train_captions),
        val_captions=len(val_captions),
    )
    return train_imgs, train_captions, val_imgs, val_captions