File size: 686 Bytes
e6f24ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""Data loaders for StyleSteer-VLM."""

from src.data.base import StyleDataset
from src.data.senticap import SentiCapDataset
from src.data.flickrstyle10k import FlickrStyle10KDataset
from src.data.personality_caps import PersonalityCapsDataset
from src.data.coco import COCODataset

# Registry mapping a single-letter track key to the dataset class that
# ``get_dataset`` instantiates for that track. Insertion order is A, B, C, D,
# which is also the order reported in ``get_dataset``'s error message.
TRACK_DATASETS = dict(
    A=SentiCapDataset,
    B=FlickrStyle10KDataset,
    C=PersonalityCapsDataset,
    D=COCODataset,
)


def get_dataset(track: str, **kwargs) -> StyleDataset:
    """Build the dataset registered under a track letter.

    Args:
        track: A key of ``TRACK_DATASETS`` (for example ``"A"``).
        **kwargs: Passed through unchanged to the selected dataset class's
            constructor.

    Returns:
        A freshly constructed instance of the class mapped to ``track``.

    Raises:
        ValueError: If ``track`` is not a key of ``TRACK_DATASETS``.
    """
    registry = TRACK_DATASETS
    if track in registry:
        dataset_cls = registry[track]
        return dataset_cls(**kwargs)
    available = list(registry.keys())
    raise ValueError(f"Unknown track: {track}. Available: {available}")