"""Abstract base class for style datasets."""

import logging
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from PIL import Image

logger = logging.getLogger(__name__)


class StyleDataset(ABC):
    """Base class for all style captioning datasets.

    Each dataset provides:
    - Images for probing (train split) and evaluation (test split)
    - Style labels
    - Optionally: ground-truth styled captions (for GT metrics)

    Subclasses must implement ``track``, ``styles``, ``has_ground_truth``
    and ``_load_data``. Records are loaded lazily on first access to
    ``data`` and cached on the instance afterwards.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "test",
        n_images: Optional[int] = None,
        seed: int = 42,
    ):
        """Store the dataset configuration without touching the disk.

        Args:
            data_dir: Root directory of the dataset; kept as a ``Path``.
            split: Name of the split to use (defaults to ``"test"``).
            n_images: If given, upper bound on the number of items that
                ``get_images`` returns per style; ``None`` disables
                sub-sampling.
            seed: Seed for the random generator used when sub-sampling.
        """
        self.data_dir = Path(data_dir)
        self.split = split
        self.n_images = n_images
        self.seed = seed
        # Populated by the ``data`` property on first access.
        self._data: Optional[List[Dict[str, Any]]] = None

    @property
    @abstractmethod
    def track(self) -> str:
        """Track letter (A, B, C, D)."""
        ...

    @property
    @abstractmethod
    def styles(self) -> List[str]:
        """List of style names in this track."""
        ...

    @property
    @abstractmethod
    def has_ground_truth(self) -> bool:
        """Whether this dataset has ground-truth styled captions."""
        ...

    @abstractmethod
    def _load_data(self) -> List[Dict[str, Any]]:
        """Load raw data from disk.

        Returns list of dicts with at minimum:
        - "image_id": str or int
        - "image_path": str (absolute path to image)
        - "style": str (style name)
        - "caption_gt": Optional[List[str]] (ground-truth captions, if available)
        """
        ...

    @property
    def data(self) -> List[Dict[str, Any]]:
        """Lazy-loaded data.

        The first access calls ``_load_data`` and caches the result; later
        accesses return the same list object.
        """
        if self._data is None:
            self._data = self._load_data()
        return self._data

    def get_images(self, style: str) -> List[Dict[str, Any]]:
        """Get all items for a given style.

        Args:
            style: Style name to filter the records by.

        Returns:
            All records whose ``"style"`` equals ``style``. When
            ``n_images`` is set, a random sample of at most ``n_images``
            of them is returned instead (in sampled order).
        """
        items = [record for record in self.data if record["style"] == style]
        if self.n_images is not None:
            # A fresh generator seeded identically on every call keeps the
            # selection reproducible across calls and across styles.
            rng = random.Random(self.seed)
            items = rng.sample(items, min(self.n_images, len(items)))
        return items

    def load_image(self, image_path: str) -> Image.Image:
        """Load a PIL Image from path.

        Args:
            image_path: Location of the image file.

        Returns:
            The image converted to RGB mode.
        """
        # ``Image.open`` is lazy and may keep the file handle open; the
        # context manager closes it once ``convert`` has produced a new,
        # fully loaded image.
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def get_ground_truth(
        self, image_id: Union[str, int], style: str
    ) -> Optional[List[str]]:
        """Get ground-truth captions for an image+style pair.

        Args:
            image_id: Identifier compared with ``==`` against each record's
                ``"image_id"``.
            style: Style name compared against each record's ``"style"``.

        Returns:
            A flat list with the captions of every matching record, or
            ``None`` when the dataset has no ground truth, nothing matches,
            or the matching records carry no captions.
        """
        if not self.has_ground_truth:
            return None

        refs: List[str] = []
        for item in self.data:
            if item["image_id"] != image_id or item["style"] != style:
                continue
            captions = item.get("caption_gt")
            if not captions:
                # Missing key, None, or an empty container/string.
                continue
            if isinstance(captions, (list, tuple)):
                refs.extend(captions)
            else:
                # A single caption stored as a bare value.
                refs.append(captions)
        return refs or None

    def __len__(self) -> int:
        """Return the total number of records (loads the data if needed)."""
        return len(self.data)

    def __repr__(self) -> str:
        """Return a short description with track, split and record count."""
        return f"{self.__class__.__name__}(track={self.track}, split={self.split}, n={len(self)})"