# abka03's picture
# Deploy StyleSteer-VLM demo
# e6f24ae verified
"""Abstract base class for style datasets."""
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
logger = logging.getLogger(__name__)
class StyleDataset(ABC):
    """Base class for all style captioning datasets.

    Each dataset provides:
    - Images for probing (train split) and evaluation (test split)
    - Style labels
    - Optionally: ground-truth styled captions (for GT metrics)

    Subclasses implement ``track``, ``styles``, ``has_ground_truth`` and
    ``_load_data``. Records are loaded lazily on first access to ``data``
    and cached on the instance afterwards.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "test",
        n_images: Optional[int] = None,
        seed: int = 42,
    ):
        """Store the dataset configuration; nothing is read from disk here.

        Args:
            data_dir: Dataset root directory; kept as a ``Path``.
            split: Name of the split to use.
            n_images: If set, ``get_images`` returns at most this many items
                per style.
            seed: Seed for the sampling done in ``get_images``.
        """
        self.data_dir = Path(data_dir)
        self.split = split
        self.n_images = n_images
        self.seed = seed
        # Filled by the ``data`` property on first access.
        self._data: Optional[List[Dict[str, Any]]] = None

    @property
    @abstractmethod
    def track(self) -> str:
        """Track letter (A, B, C, D)."""
        ...

    @property
    @abstractmethod
    def styles(self) -> List[str]:
        """List of style names in this track."""
        ...

    @property
    @abstractmethod
    def has_ground_truth(self) -> bool:
        """Whether this dataset has ground-truth styled captions."""
        ...

    @abstractmethod
    def _load_data(self) -> List[Dict[str, Any]]:
        """Load raw data from disk.

        Returns list of dicts with at minimum:
        - "image_id": str or int
        - "image_path": str (absolute path to image)
        - "style": str (style name)
        - "caption_gt": Optional[List[str]] (ground-truth captions, if available)
        """
        ...

    @property
    def data(self) -> List[Dict[str, Any]]:
        """Records from ``_load_data``, loaded on first access and then cached."""
        if self._data is None:
            self._data = self._load_data()
        return self._data

    def get_images(self, style: str) -> List[Dict[str, Any]]:
        """Return the records whose ``"style"`` equals ``style``.

        Without ``n_images`` the matching records are returned in data order.
        With ``n_images`` set, a sample of at most ``n_images`` records is
        drawn; the order of a sample is whatever ``random.sample`` yields.

        Args:
            style: Style name to filter on.

        Returns:
            A new list of the matching record dicts (possibly empty).
        """
        matching = [record for record in self.data if record["style"] == style]
        if self.n_images is None:
            return matching

        import random

        # A fresh generator seeded with ``self.seed`` on every call makes
        # repeated calls return the same subset for the same data.
        sampler = random.Random(self.seed)
        sample_size = min(self.n_images, len(matching))
        return sampler.sample(matching, sample_size)

    def load_image(self, image_path: str) -> "Image.Image":
        """Open the image at ``image_path`` and return it converted to RGB.

        Args:
            image_path: Path of the image file to read.

        Returns:
            The RGB image produced by ``convert("RGB")``.
        """
        # The context manager closes the opened image once the converted
        # result has been produced, instead of leaving that to Pillow or GC.
        with Image.open(image_path) as source:
            return source.convert("RGB")

    def get_ground_truth(self, image_id: str, style: str) -> Optional[List[str]]:
        """Collect ground-truth captions for an image+style pair.

        Captions from every record matching both ``image_id`` and ``style``
        are concatenated in data order. A ``"caption_gt"`` value that is a
        list is extended into the result; any other truthy value is appended
        as a single caption.

        Args:
            image_id: Identifier compared with ``==`` against each record's
                ``"image_id"``, so it must be of the same type as stored.
            style: Style name compared against each record's ``"style"``.

        Returns:
            The collected captions, or ``None`` when the dataset has no
            ground truth, no record matches, or no matching record carries a
            non-empty ``"caption_gt"``.
        """
        if not self.has_ground_truth:
            return None

        captions: List[str] = []
        for record in self.data:
            if record["image_id"] == image_id and record["style"] == style:
                found = record.get("caption_gt")
                if not found:
                    continue
                if isinstance(found, list):
                    captions.extend(found)
                else:
                    captions.append(found)
        return captions or None

    def __len__(self) -> int:
        """Number of records in ``data`` (loads the data if not yet loaded)."""
        return len(self.data)

    def __repr__(self) -> str:
        """Short summary with class name, track, split and record count.

        Note: the count comes from ``len(self)``, so this loads the data if
        it has not been loaded yet.
        """
        return f"{type(self).__name__}(track={self.track}, split={self.split}, n={len(self)})"