Spaces:
Configuration error
Configuration error
| """Image-level train/val splitting. | |
| Mirrors notebook cell 11. The split is intentionally at the *image* level | |
| (not the caption level): every image owns ~5 captions in COCO, and putting | |
| some of an image's captions in train and others in val would be data leakage. | |
| The notebook does this correctly via the ``img_to_cap_vector`` defaultdict | |
| loop; we preserve that exact algorithm but inject the seed so the split is | |
| reproducible across runs. | |
| """ | |
| from __future__ import annotations | |
| import collections | |
| import random | |
| import pandas as pd | |
| from captioning.utils.logging import get_logger | |
| log = get_logger(__name__) | |
def _expand_pairs(
    keys: list[str],
    img_to_cap: dict[str, list[str]],
) -> tuple[list[str], list[str]]:
    """Flatten ``keys`` into parallel image/caption lists, one entry per pair."""
    imgs: list[str] = []
    caps: list[str] = []
    for key in keys:
        key_caps = img_to_cap[key]
        # Repeat the image id once per caption so both lists stay aligned.
        imgs.extend([key] * len(key_caps))
        caps.extend(key_caps)
    return imgs, caps


def make_image_level_splits(
    captions: pd.DataFrame,
    train_fraction: float = 0.8,
    seed: int | None = None,
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Split captions into train/val, keeping each image's captions together.

    Captions are grouped by image first and the *images* are shuffled and
    sliced, so no image contributes captions to both splits (which would be
    data leakage). This mirrors notebook cell 11; passing the same ``seed``
    that was fed to ``random.seed`` before the notebook ran is intended to
    reproduce its shuffle, while ``seed=None`` keeps the notebook's
    non-deterministic behaviour.

    Args:
        captions: DataFrame with ``image`` and ``caption`` columns
            (preprocessed if you want preprocessed splits — the loader
            applies ``preprocess_caption`` upstream).
        train_fraction: Fraction of *unique images* assigned to the train
            split, in ``[0, 1]``. The cut point is
            ``int(n_images * train_fraction)``, which truncates rather than
            rounds (as the notebook does).
        seed: If provided, seeds a private ``random.Random`` used for the
            shuffle, leaving the global ``random`` state untouched. If
            ``None``, the module-level ``random.shuffle`` is used.

    Returns:
        Tuple ``(train_imgs, train_captions, val_imgs, val_captions)`` where
        each list has one entry per (image, caption) pair, expanded so an
        image with N captions appears N times.

    Raises:
        ValueError: If ``train_fraction`` is not within ``[0, 1]``.
    """
    # A negative fraction would become a negative slice index (slicing from
    # the end) and a fraction above 1 would silently empty the val split;
    # reject both instead of returning a misleading partition.
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be within [0, 1], got {train_fraction!r}"
        )

    # Group captions by image; dict insertion order follows first appearance.
    img_to_cap: dict[str, list[str]] = collections.defaultdict(list)
    for img, cap in zip(captions["image"], captions["caption"], strict=True):
        img_to_cap[img].append(cap)

    img_keys = list(img_to_cap)
    if seed is not None:
        # Dedicated RNG instance: reproducible without reseeding global state.
        random.Random(seed).shuffle(img_keys)
    else:
        random.shuffle(img_keys)

    # int() truncates, matching the notebook's int(len(img_keys) * 0.8).
    slice_index = int(len(img_keys) * train_fraction)
    train_keys, val_keys = img_keys[:slice_index], img_keys[slice_index:]

    train_imgs, train_captions = _expand_pairs(train_keys, img_to_cap)
    val_imgs, val_captions = _expand_pairs(val_keys, img_to_cap)

    log.info(
        "splits_made",
        train_images=len(train_keys),
        val_images=len(val_keys),
        train_captions=len(train_captions),
        val_captions=len(val_captions),
    )
    return train_imgs, train_captions, val_imgs, val_captions