abka03's picture
Deploy StyleSteer-VLM demo
e6f24ae verified
"""Track D β€” COCO val2017 dataset loader.
Used for custom rhetorical styles (poetic, scientific, narrative, factual,
natural, fictional). No ground-truth styled captions — LLM judge only.
"""
import json
import logging
import random
from pathlib import Path
from typing import Any, Dict, List, Optional
from src.data.base import StyleDataset
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)
class COCODataset(StyleDataset):
    """COCO val2017 for custom rhetorical tracks (Track D).

    Expected directory structure::

        COCO_IMAGE_DIR/ (from .env, e.g. data/coco_val2017/images/)
        ├── 000000000139.jpg
        └── ...
        COCO_ANNOT_FILE/ (from .env)
        └── captions_val2017.json

    Track D has no ground-truth styled captions: every item produced by this
    loader carries ``caption_gt=None``. The annotation file is read only for
    its ``"images"`` list (the image IDs); captions in it are not loaded here.

    The loader relies on ``self.data_dir``, ``self.seed``, ``self.n_images``
    and ``self.data``, which are not set in this class.
    NOTE(review): presumably they are provided by ``StyleDataset`` -- confirm
    against the base class.
    """

    RHETORICAL_STYLES = ["poetic", "scientific", "narrative", "factual", "natural", "fictional"]

    def __init__(
        self,
        data_dir: str = "",
        image_dir: str = "",
        annot_file: str = "",
        **kwargs,
    ) -> None:
        """Store the optional path overrides, then initialise the base dataset.

        Args:
            data_dir: Dataset root, forwarded to the base class.
            image_dir: Explicit image directory; when empty, ``<data_dir>/images``
                is used at load time.
            annot_file: Explicit annotation JSON; when empty,
                ``<data_dir>/annotations/captions_val2017.json`` is used.
            **kwargs: Forwarded unchanged to the base class initialiser.
        """
        # Set the overrides *before* calling the base initialiser.
        # NOTE(review): presumably StyleDataset.__init__ triggers _load_data(),
        # which reads these attributes -- keep this ordering.
        self.image_dir: Optional[Path] = Path(image_dir) if image_dir else None
        self.annot_file: Optional[Path] = Path(annot_file) if annot_file else None
        super().__init__(data_dir=data_dir, **kwargs)

    @property
    def track(self) -> str:
        """Track identifier of this dataset (always ``"D"``)."""
        return "D"

    @property
    def styles(self) -> List[str]:
        """Rhetorical styles generated for every image."""
        return self.RHETORICAL_STYLES

    @property
    def has_ground_truth(self) -> bool:
        """Always ``False``: Track D has no reference styled captions."""
        return False

    @staticmethod
    def _image_ids_from_files(img_dir: Path) -> List[int]:
        """Derive image IDs from the ``*.jpg`` file names in ``img_dir``.

        Files whose stem is not an integer are skipped (with a warning) instead
        of aborting the load. The result is sorted because ``Path.glob`` gives
        no ordering guarantee, and seeded sampling is only reproducible when
        its input order is.
        """
        ids: List[int] = []
        skipped = 0
        for path in img_dir.glob("*.jpg"):
            try:
                ids.append(int(path.stem))
            except ValueError:
                skipped += 1
        if skipped:
            logger.warning(
                "Skipped %d .jpg file(s) in %s whose name is not a numeric image id.",
                skipped,
                img_dir,
            )
        ids.sort()
        return ids

    def _load_data(self) -> List[Dict[str, Any]]:
        """Build one item per (image, style) pair.

        Returns:
            A list of dicts with keys ``image_id`` (str), ``image_path`` (str),
            ``style`` (str) and ``caption_gt`` (always ``None``). If the image
            directory does not exist, mock items from ``_mock_data`` are
            returned instead.
        """
        # Explicit overrides win; otherwise fall back to the layout under data_dir.
        img_dir = Path(self.image_dir or self.data_dir / "images")
        annot = Path(
            self.annot_file or self.data_dir / "annotations" / "captions_val2017.json"
        )

        if not img_dir.exists():
            logger.warning("COCO images not found at %s. Using mock data.", img_dir)
            return self._mock_data()

        if annot.exists():
            # Only the image list is needed; keep the file's own ordering so a
            # given seed keeps selecting the same subset as before.
            with open(annot, encoding="utf-8") as f:
                coco_data = json.load(f)
            image_ids = [img["id"] for img in coco_data.get("images", [])]
        else:
            # No annotation file: fall back to listing the image files.
            image_ids = self._image_ids_from_files(img_dir)

        # Deterministic subset: seeded sampling, capped at n_images (default 500).
        rng = random.Random(self.seed)
        n = self.n_images or 500
        if len(image_ids) > n:
            image_ids = rng.sample(image_ids, n)
        image_ids.sort()

        # Create entries — one per image per style.
        data: List[Dict[str, Any]] = []
        for img_id in image_ids:
            # COCO file names are the image id zero-padded to 12 digits.
            image_path = str(img_dir / f"{img_id:012d}.jpg")
            for style in self.styles:
                data.append({
                    "image_id": str(img_id),
                    "image_path": image_path,
                    "style": style,
                    "caption_gt": None,  # No ground truth for Track D
                })

        logger.info(
            "COCO Track D: %d images × %d styles = %d items",
            len(image_ids),
            len(self.styles),
            len(data),
        )
        return data

    def get_images(self, style: str) -> List[Dict[str, Any]]:
        """Return the items for ``style``.

        Each image has exactly one entry per style, so filtering by style
        already yields one item per image; no further de-duplication is needed.
        """
        return [d for d in self.data if d["style"] == style]

    def _mock_data(self) -> List[Dict[str, Any]]:
        """Return placeholder items (at least 5 images) for every style.

        Used when the image directory is missing; the image paths are
        ``mock_image_<i>.jpg`` names rather than paths under the image
        directory.
        """
        n_mock = max(self.n_images or 5, 5)
        return [
            {
                "image_id": f"mock_{i}",
                "image_path": f"mock_image_{i}.jpg",
                "style": style,
                "caption_gt": None,
            }
            for i in range(n_mock)
            for style in self.styles
        ]