File size: 3,133 Bytes
e6f24ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Abstract base class for style datasets."""

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class StyleDataset(ABC):
    """Base class for all style captioning datasets.

    Each dataset provides:
    - Images for probing (train split) and evaluation (test split)
    - Style labels
    - Optionally: ground-truth styled captions (for GT metrics)

    Subclasses implement ``track``, ``styles``, ``has_ground_truth`` and
    ``_load_data``. The records returned by ``_load_data`` are loaded lazily
    on first access to ``data`` and cached on the instance afterwards.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "test",
        n_images: Optional[int] = None,
        seed: int = 42,
    ):
        """
        Args:
            data_dir: Root directory of the dataset (stored as a ``Path``).
            split: Split name, e.g. "train" or "test". The base class only
                uses it in ``__repr__``; presumably subclasses read it in
                ``_load_data`` — confirm per subclass.
            n_images: If set, ``get_images`` returns at most this many items
                per style (a seeded random subset).
            seed: Seed for the subsampling performed by ``get_images``.
        """
        self.data_dir = Path(data_dir)
        self.split = split
        self.n_images = n_images
        self.seed = seed
        # Cache for the records from _load_data(); filled on first `data` access.
        self._data: Optional[List[Dict[str, Any]]] = None

    @property
    @abstractmethod
    def track(self) -> str:
        """Track letter (A, B, C, D)."""
        ...

    @property
    @abstractmethod
    def styles(self) -> List[str]:
        """List of style names in this track."""
        ...

    @property
    @abstractmethod
    def has_ground_truth(self) -> bool:
        """Whether this dataset has ground-truth styled captions."""
        ...

    @abstractmethod
    def _load_data(self) -> List[Dict[str, Any]]:
        """Load raw data from disk.

        Returns list of dicts with at minimum:
            - "image_id": str or int
            - "image_path": str (absolute path to image)
            - "style": str (style name)
            - "caption_gt": Optional[List[str]] (ground-truth captions, if available)
        """
        ...

    @property
    def data(self) -> List[Dict[str, Any]]:
        """All records, loaded via ``_load_data`` on first access and cached."""
        if self._data is None:
            self._data = self._load_data()
        return self._data

    def get_images(self, style: str) -> List[Dict[str, Any]]:
        """Return the items whose ``"style"`` equals *style*.

        When ``n_images`` is set, a random subset of at most ``n_images``
        items is returned instead. A fresh ``random.Random(self.seed)`` is
        created on every call, so repeated calls for the same style return
        the same items in the same order. Sampled items come back in sampling
        order, not in the order of ``data``.
        """
        items = [d for d in self.data if d["style"] == style]
        if self.n_images is not None:
            import random
            rng = random.Random(self.seed)
            items = rng.sample(items, min(self.n_images, len(items)))
        return items

    def load_image(self, image_path: str) -> Image.Image:
        """Load the image at *image_path* and return it converted to RGB.

        ``Image.open`` is lazy and may keep the file handle open (e.g. for
        multi-frame formats) until the image object is garbage-collected.
        Opening it in a ``with`` block releases the handle right away.
        ``convert`` loads the pixels and returns a new in-memory image, so
        the returned object stays valid after the file is closed.
        """
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def get_ground_truth(self, image_id: str, style: str) -> Optional[List[str]]:
        """Return the ground-truth captions for an (image, style) pair.

        Args:
            image_id: Image identifier. It is compared by string value, so an
                int id stored in the data matches its ``str`` form and vice
                versa (``_load_data`` allows either type).
            style: Style name to match.

        Returns:
            The captions of every matching record, concatenated in data
            order. Returns ``None`` if the dataset has no ground truth or no
            matching record has a non-empty ``"caption_gt"``.
        """
        if not self.has_ground_truth:
            return None
        wanted_id = str(image_id)
        refs: List[str] = []
        for item in self.data:
            # Same short-circuit order as before: id first, then style.
            if str(item["image_id"]) != wanted_id or item["style"] != style:
                continue
            captions = item.get("caption_gt")
            if not captions:
                continue  # None / missing / empty captions contribute nothing
            if isinstance(captions, (list, tuple)):
                refs.extend(captions)
            else:
                # A single caption string is treated as a one-element list.
                refs.append(captions)
        return refs if refs else None

    def __len__(self) -> int:
        """Number of records; triggers the lazy load of ``data``."""
        return len(self.data)

    def __repr__(self) -> str:
        # Note: len(self) triggers the lazy load of `data` if not yet loaded.
        return f"{self.__class__.__name__}(track={self.track}, split={self.split}, n={len(self)})"