# stylsteer-vlm / src/data/personality_caps.py
# abka03 — Deploy StyleSteer-VLM demo (commit e6f24ae, verified)
"""Track C β€” Personality-Captions dataset loader.
Personality-Captions (Meta AI) provides images with personality-conditioned captions.
Uses top-10 personalities by frequency from the train split.
Ground truth is available.
"""
import json
import logging
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional
from src.data.base import StyleDataset
logger = logging.getLogger(__name__)

# Estimated top-10 — will be replaced by the actual frequency count in S1.
DEFAULT_TOP10 = [
    "sweet", "dramatic", "cheerful", "sarcastic", "nerdy",
    "witty", "melancholic", "philosophical", "anxious", "confident",
]


def compute_top10_personalities(data_dir: str, n: int = 10) -> List[str]:
    """Compute the top-N personalities by frequency from the train split.

    This function MUST be called during S1 scaffold to confirm
    the actual top-10 personality list.

    Args:
        data_dir: Directory containing ``train.json`` or the ParlAI-style
            ``personality_captions_train.json``.
        n: Number of personality labels to return.

    Returns:
        The ``n`` most frequent labels (stripped, lowercased), or
        ``DEFAULT_TOP10[:n]`` when no train file exists in ``data_dir``.
    """
    train_path = Path(data_dir) / "train.json"
    if not train_path.exists():
        # Fall back to the ParlAI naming convention.
        train_path = Path(data_dir) / "personality_captions_train.json"
    if not train_path.exists():
        logger.warning("Train file not found at %s. Using default top-10.", train_path)
        return DEFAULT_TOP10[:n]
    with open(train_path, encoding="utf-8") as f:
        raw = json.load(f)
    counter: Counter = Counter()
    # Accept either a bare list or a dict wrapping items under "data"/"annotations".
    items = raw if isinstance(raw, list) else raw.get("data", raw.get("annotations", []))
    for item in items:
        personality = item.get("personality", item.get("style", ""))
        if isinstance(personality, list):
            personality = personality[0] if personality else ""
        if not isinstance(personality, str):
            # Guard against None / non-string payloads in malformed rows,
            # which previously raised AttributeError on .strip().
            personality = ""
        personality = personality.strip().lower()
        if personality:
            counter[personality] += 1
    top_n = [p for p, _ in counter.most_common(n)]
    logger.info("Top-%d personalities by frequency: %s", n, top_n)
    logger.info("Frequency counts: %s", counter.most_common(n))
    return top_n
class PersonalityCapsDataset(StyleDataset):
    """Personality-Captions dataset (Track C).

    Images paired with personality-conditioned captions (Meta AI).

    Expected directory structure::

        data_dir/
        ├── train.json / personality_captions_train.json
        ├── test.json / personality_captions_test.json
        ├── val.json / personality_captions_val.json
        └── images/   # or yfcc_images/
    """

    def __init__(self, *args, top_k: int = 10, **kwargs):
        # top_k: how many of the most frequent personalities to keep.
        self.top_k = top_k
        # Computed lazily in `styles` so the base class can set data_dir first.
        self._top_personalities: Optional[List[str]] = None
        super().__init__(*args, **kwargs)

    @property
    def track(self) -> str:
        """Benchmark track identifier."""
        return "C"

    @property
    def styles(self) -> List[str]:
        """Top-k personality labels, computed once from the train split."""
        if self._top_personalities is None:
            self._top_personalities = compute_top10_personalities(
                str(self.data_dir), self.top_k
            )
        return self._top_personalities

    @property
    def has_ground_truth(self) -> bool:
        """Ground-truth captions are available for this dataset."""
        return True

    @staticmethod
    def _normalize_style(item: Dict[str, Any]) -> str:
        """Extract and normalize the personality label from a raw item.

        Handles both "personality" and "style" keys, list-valued labels,
        and None / non-string payloads; returns "" when no usable label
        exists. Mirrors the normalization in compute_top10_personalities.
        """
        personality = item.get("personality", item.get("style", ""))
        if isinstance(personality, list):
            personality = personality[0] if personality else ""
        if not isinstance(personality, str):
            return ""
        return personality.strip().lower()

    def _load_data(self) -> List[Dict[str, Any]]:
        """Load, filter, and aggregate items for the current split.

        Returns:
            One record per (image_id, style) pair with all captions merged;
            mock data when no split file is found on disk.
        """
        split_map = {
            "train": ["train.json", "personality_captions_train.json"],
            "test": ["test.json", "personality_captions_test.json"],
            "val": ["val.json", "personality_captions_val.json"],
        }
        filenames = split_map.get(self.split, [f"{self.split}.json"])
        raw = None
        for fname in filenames:
            fpath = self.data_dir / fname
            if fpath.exists():
                with open(fpath, encoding="utf-8") as f:
                    raw = json.load(f)
                break
        if raw is None:
            logger.warning("PersonalityCaps: no data found for split=%s. Using mock.", self.split)
            return self._mock_data()
        # Accept a bare list or a dict wrapping items under "data"/"annotations".
        items_list = raw if isinstance(raw, list) else raw.get("data", raw.get("annotations", []))
        valid_styles = set(self.styles)
        data = []
        for item in items_list:
            personality = self._normalize_style(item)
            # Drop items whose label is not in the chosen top-k set.
            if personality not in valid_styles:
                continue
            image_id = str(item.get("image_hash", item.get("image_id", item.get("id", ""))))
            # Image path — try multiple naming conventions.
            img_filename = item.get("image_path", item.get("filename", ""))
            if not img_filename:
                img_filename = f"{image_id}.jpg"
            image_path = str(self.data_dir / "images" / img_filename)
            if not Path(image_path).exists():
                # YFCC dump layout; kept even if also missing so failures
                # surface at image-load time rather than being dropped here.
                image_path = str(self.data_dir / "yfcc_images" / img_filename)
            caption = item.get("comment", item.get("caption", item.get("text", "")))
            data.append({
                "image_id": image_id,
                "image_path": image_path,
                "style": personality,
                "caption_gt": [caption] if caption else [],
            })
        # Merge duplicate (image, style) rows into single records.
        data = self._aggregate_captions(data)
        logger.info(
            "PersonalityCaps %s: %d items across %d personalities",
            self.split, len(data), len(valid_styles),
        )
        return data

    def _aggregate_captions(self, data: List[Dict]) -> List[Dict]:
        """Merge records that share the same (image_id, style) key.

        Captions are concatenated in input order; the first record's
        image_path wins for duplicate keys.
        """
        key_map: Dict[tuple, Dict] = {}
        for d in data:
            key = (d["image_id"], d["style"])
            if key not in key_map:
                key_map[key] = {
                    "image_id": d["image_id"],
                    "image_path": d["image_path"],
                    "style": d["style"],
                    "caption_gt": [],
                }
            key_map[key]["caption_gt"].extend(d["caption_gt"])
        return list(key_map.values())

    def _mock_data(self) -> List[Dict[str, Any]]:
        """Synthesize deterministic placeholder items (no files on disk)."""
        data = []
        # NOTE(review): assumes self.n_images is provided by the base
        # class and may be None — confirm against StyleDataset.
        for i in range(max(self.n_images or 5, 5)):
            for style in self.styles:
                data.append({
                    "image_id": f"mock_{i}",
                    "image_path": f"mock_image_{i}.jpg",
                    "style": style,
                    "caption_gt": [f"A mock {style} personality caption for image {i}."],
                })
        return data