Spaces:
Running
Running
| """Abstract base class for style datasets.""" | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from PIL import Image | |
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class StyleDataset(ABC):
    """Base class for all style captioning datasets.

    Each dataset provides:
    - Images for probing (train split) and evaluation (test split)
    - Style labels
    - Optionally: ground-truth styled captions (for GT metrics)

    Subclasses implement the abstract properties ``track``, ``styles`` and
    ``has_ground_truth`` plus the ``_load_data`` loader. Records are loaded
    lazily on first access to ``data`` and then cached on the instance.

    Args:
        data_dir: Dataset root; stored as a ``pathlib.Path``.
        split: Split name. The base class only stores it and shows it in
            ``repr``; how it is used is up to ``_load_data``.
        n_images: If set, ``get_images`` returns at most this many records
            per style, sampled deterministically with ``seed``.
        seed: Seed for the per-style sampling done in ``get_images``.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "test",
        n_images: Optional[int] = None,
        seed: int = 42,
    ):
        self.data_dir = Path(data_dir)
        self.split = split
        self.n_images = n_images
        self.seed = seed
        # Filled in on first access to the ``data`` property.
        self._data: Optional[List[Dict[str, Any]]] = None

    @property
    @abstractmethod
    def track(self) -> str:
        """Track letter (A, B, C, D)."""
        ...

    @property
    @abstractmethod
    def styles(self) -> List[str]:
        """List of style names in this track."""
        ...

    @property
    @abstractmethod
    def has_ground_truth(self) -> bool:
        """Whether this dataset has ground-truth styled captions."""
        ...

    @abstractmethod
    def _load_data(self) -> List[Dict[str, Any]]:
        """Load raw data from disk.

        Returns list of dicts with at minimum:
            - "image_id": str or int
            - "image_path": str (absolute path to image)
            - "style": str (style name)
            - "caption_gt": Optional[List[str]] (ground-truth captions, if available)
        """
        ...

    @property
    def data(self) -> List[Dict[str, Any]]:
        """All records, loaded via ``_load_data`` on first access and cached."""
        if self._data is None:
            self._data = self._load_data()
        return self._data

    def get_images(self, style: str) -> List[Dict[str, Any]]:
        """Return the records whose ``"style"`` equals *style*.

        When ``n_images`` is set, at most ``n_images`` records are drawn with
        a freshly seeded ``random.Random(self.seed)``, so repeated calls on
        the same data return the same sample. The sampled list is in random
        order, even when every matching record is kept.
        """
        items = [d for d in self.data if d["style"] == style]
        if self.n_images is not None:
            import random

            rng = random.Random(self.seed)
            items = rng.sample(items, min(self.n_images, len(items)))
        return items

    def load_image(self, image_path: str) -> "Image.Image":
        """Open *image_path* and return it as an RGB PIL image.

        The file is opened in a ``with`` block so its handle is released as
        soon as ``convert`` has produced the independent in-memory RGB copy.
        """
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def get_ground_truth(self, image_id: str, style: str) -> Optional[List[str]]:
        """Collect ground-truth captions for one (image, style) pair.

        Every record matching both *image_id* and *style* contributes its
        ``"caption_gt"`` value:
            - a list is added element-wise;
            - any other truthy value is added as a single caption;
            - missing or empty values are skipped.
        Matching uses ``==``, so *image_id* must have the same type as the
        ids produced by ``_load_data``.

        Returns:
            The collected captions, or ``None`` when the dataset has no
            ground truth, nothing matches, or no match carries captions.
        """
        if not self.has_ground_truth:
            return None
        refs: List[str] = []
        for item in self.data:
            if not (item["image_id"] == image_id and item["style"] == style):
                continue
            captions = item.get("caption_gt")
            if not captions:
                continue
            refs.extend(captions if isinstance(captions, list) else [captions])
        return refs or None

    def __len__(self) -> int:
        """Number of loaded records (triggers loading on first call)."""
        return len(self.data)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(track={self.track}, split={self.split}, n={len(self)})"