File size: 3,854 Bytes
e6f24ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Track D — COCO val2017 dataset loader.

Used for custom rhetorical styles (poetic, scientific, narrative, factual,
natural, fictional). No ground-truth styled captions — LLM judge only.
"""

import json
import logging
import random
from pathlib import Path
from typing import Any, Dict, List, Optional

from src.data.base import StyleDataset

logger = logging.getLogger(__name__)


class COCODataset(StyleDataset):
    """COCO val2017 for custom rhetorical tracks (Track D).

    Expected directory structure:
        COCO_IMAGE_DIR/   (from .env, e.g. data/coco_val2017/images/)
            ├── 000000000139.jpg
            └── ...
        COCO_ANNOT_FILE/  (from .env)
            └── captions_val2017.json

    No ground-truth styled captions. Factual COCO captions used as neutral reference only.

    Every loaded item is a dict with the keys ``image_id`` (str),
    ``image_path`` (str), ``style`` (str) and ``caption_gt`` (always ``None``).
    One item is produced per image per style.

    NOTE(review): ``self.data_dir``, ``self.seed``, ``self.n_images`` and
    ``self.data`` are read here but never assigned in this class; they are
    presumably provided by ``StyleDataset`` — confirm against the base class.
    """

    RHETORICAL_STYLES = ["poetic", "scientific", "narrative", "factual", "natural", "fictional"]

    def __init__(
        self,
        data_dir: str = "",
        image_dir: str = "",
        annot_file: str = "",
        **kwargs,
    ):
        """Store optional path overrides and delegate to the base initializer.

        Args:
            data_dir: Dataset root, forwarded to ``StyleDataset``.
            image_dir: Explicit image directory; empty means
                ``<data_dir>/images``.
            annot_file: Explicit annotation path; empty means
                ``<data_dir>/annotations/captions_val2017.json``.
            **kwargs: Forwarded unchanged to ``StyleDataset.__init__``.
        """
        # The overrides are assigned *before* super().__init__() — presumably
        # because the base initializer triggers _load_data(), which reads
        # them. TODO confirm against StyleDataset.
        self.image_dir = Path(image_dir) if image_dir else None
        self.annot_file = Path(annot_file) if annot_file else None
        super().__init__(data_dir=data_dir, **kwargs)

    @property
    def track(self) -> str:
        """Track identifier for this dataset (always ``"D"``)."""
        return "D"

    @property
    def styles(self) -> List[str]:
        """Rhetorical styles covered by this track.

        A fresh list is returned so callers cannot mutate the shared
        class-level ``RHETORICAL_STYLES`` constant by accident.
        """
        return list(self.RHETORICAL_STYLES)

    @property
    def has_ground_truth(self) -> bool:
        """Whether styled reference captions exist (never, for Track D)."""
        return False

    def _load_data(self) -> List[Dict[str, Any]]:
        """Build the item list: one entry per selected image per style.

        Image ids come from the ``images`` section of the annotation file
        when it exists; otherwise they are derived from the ``*.jpg`` file
        names in the image directory. If more ids are available than
        ``self.n_images`` (500 when unset/falsy), a subset is drawn with
        ``random.Random(self.seed)``. The selected ids are sorted ascending.

        Returns:
            The list of item dicts, or mock data when the image directory
            does not exist.
        """
        # Path(...) around data_dir is a no-op when the base class already
        # stores a Path and protects the "/" join if it stores a plain str.
        img_dir = Path(self.image_dir or Path(self.data_dir) / "images")
        annot = Path(
            self.annot_file
            or Path(self.data_dir) / "annotations" / "captions_val2017.json"
        )
        # The documented layout shows COCO_ANNOT_FILE as a directory holding
        # captions_val2017.json; accept that form instead of failing in open().
        if annot.is_dir():
            annot = annot / "captions_val2017.json"

        if not img_dir.exists():
            logger.warning(f"COCO images not found at {img_dir}. Using mock data.")
            return self._mock_data()

        if annot.exists():
            # Explicit encoding so parsing does not depend on the platform's
            # default locale encoding.
            with open(annot, encoding="utf-8") as f:
                coco_data = json.load(f)
            image_ids = [img["id"] for img in coco_data.get("images", [])]
        else:
            # Fall back to listing image files. Files whose stem is not an
            # integer are skipped rather than aborting the whole load.
            found = set()
            for jpg in img_dir.glob("*.jpg"):
                try:
                    found.add(int(jpg.stem))
                except ValueError:
                    logger.debug(f"Skipping non-numeric image file name: {jpg.name}")
            # glob() yields files in arbitrary, filesystem-dependent order;
            # sort so the seeded sample below is reproducible across machines.
            image_ids = sorted(found)

        # Deterministic subset
        rng = random.Random(self.seed)
        n = self.n_images or 500
        if len(image_ids) > n:
            image_ids = rng.sample(image_ids, n)
        image_ids.sort()

        # Create entries — one per image per style
        styles = self.styles  # evaluate the property once, not per image
        data = []
        for img_id in image_ids:
            # Image files are named as the id zero-padded to 12 digits.
            image_path = str(img_dir / f"{img_id:012d}.jpg")
            for style in styles:
                data.append({
                    "image_id": str(img_id),
                    "image_path": image_path,
                    "style": style,
                    "caption_gt": None,  # No ground truth for Track D
                })

        logger.info(f"COCO Track D: {len(image_ids)} images × {len(styles)} styles = {len(data)} items")
        return data

    def get_images(self, style: str) -> List[Dict[str, Any]]:
        """Return the items of ``self.data`` whose ``style`` equals *style*.

        Each image contributes exactly one item per style, so the result
        contains no duplicate images. An unknown style yields an empty list.
        """
        return [item for item in self.data if item["style"] == style]

    def _mock_data(self) -> List[Dict[str, Any]]:
        """Return placeholder items used when no COCO images are on disk.

        Produces at least five mock images (more if ``self.n_images`` is
        larger), each repeated once per style, with the same keys as real
        items.
        """
        n_mock = max(self.n_images or 5, 5)
        styles = self.styles
        return [
            {
                "image_id": f"mock_{i}",
                "image_path": f"mock_image_{i}.jpg",
                "style": style,
                "caption_gt": None,
            }
            for i in range(n_mock)
            for style in styles
        ]