File size: 5,462 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import logging
from typing import Dict, List

from datasets import DatasetDict, load_dataset

from larm.data.builders.base_builder import BaseDatasetBuilder
from larm.data.interactions.singleturn_interaction import SingleTurnInteractionManager
from larm.common.registry import registry
from larm.data.envs.math_vision_env import MathVisionEnv


@registry.register_builder("math_vision")
class MathVisionBuilder(BaseDatasetBuilder):
    """
    Math Vision dataset builder.

    Expects preprocessed data generated by scripts/math_vision_process.py
    with fields: prompt, completion, solution, image_path

    Expected directory structure:
        data_path/
          train.json   # preprocessed samples with prompt, completion, solution, image_path
          valid.json
          test.json

    Any subset of the split files may be present, but at least one must exist.
    """

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/math_vision/default.yaml",
    }
    CACHE_PATH = None
    # Directory holding the preprocessed split files, resolved relative to the
    # current working directory. Shared by SFT and GRPO; subclasses may override.
    DATA_PATH = "data/math_vision"
    # Split names probed as ``<split>.json`` under DATA_PATH, in this order.
    SPLIT_NAMES = ("train", "valid", "test")

    _logger = logging.getLogger(__name__)

    def _build_datasets(self) -> DatasetDict:
        """Build SFT/RL datasets from preprocessed JSON files.

        The preprocessing (download, split, format) should be done via:
            python scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml

        Data is loaded from ``DATA_PATH`` (``data/math_vision`` by default);
        the location comes from the class attribute, not from the config.

        Returns:
            A ``DatasetDict`` with one entry per split file found. Each split is
            truncated to ``max_samples`` rows when that setting is positive, and
            restricted to the columns returned by ``_keep_keys()``. In ``grpo``
            mode the ``completion`` column is set to the empty string.

        Raises:
            FileNotFoundError: if none of the split files exist.
            ValueError: if a split is missing any required column.
        """
        logger = self._logger
        data_path = self.DATA_PATH
        logger.debug("[Math_Vision] Using unified data path: %s", data_path)

        # Collect whichever preprocessed split files are present.
        data_files = {}
        for split_name in self.SPLIT_NAMES:
            json_path = os.path.join(data_path, f"{split_name}.json")
            if os.path.exists(json_path):
                data_files[split_name] = json_path

        if not data_files:
            raise FileNotFoundError(
                f"No preprocessed data found in {data_path}. "
                f"Please run 'python scripts/math_vision_process.py --output_dir {data_path}' first."
            )

        logger.debug("[Math_Vision] Found splits: %s", list(data_files))
        dataset_dict = load_dataset("json", data_files=data_files)

        # Validate the schema up front so stale preprocessing fails fast.
        required_fields = self._keep_keys()
        for split_name, ds in dataset_dict.items():
            missing_keys = sorted(set(required_fields) - set(ds.column_names))
            if missing_keys:
                raise ValueError(
                    f"Split '{split_name}' is missing required fields: {missing_keys}. "
                    f"Please regenerate preprocessed data using scripts/math_vision_process.py"
                )

        # Optionally truncate every split to its first ``max_samples`` rows.
        max_samples = self._resolve_max_samples()
        if max_samples is not None and max_samples > 0:
            for split_name in list(dataset_dict.keys()):
                original_size = len(dataset_dict[split_name])
                if original_size > max_samples:
                    dataset_dict[split_name] = dataset_dict[split_name].select(range(max_samples))
                    logger.debug(
                        "[Math_Vision] %s: %d -> %d samples",
                        split_name, original_size, max_samples,
                    )

        # sft: keep ``completion`` (full solution text) as the target.
        # grpo: blank ``completion`` because the policy generates from scratch;
        #       the extracted answer remains available in ``solution``.
        if self.mode == "grpo":
            def _blank_completion(example):
                # Returning a partial dict updates only that column.
                return {"completion": ""}

            for split_name in dataset_dict.keys():
                dataset_dict[split_name] = dataset_dict[split_name].map(_blank_completion)

        keep_keys = self._keep_keys()
        for split_name in dataset_dict.keys():
            dataset_dict[split_name] = dataset_dict[split_name].select_columns(keep_keys)

        # Only materialize a sample row when debug logging is actually enabled.
        if (
            logger.isEnabledFor(logging.DEBUG)
            and "train" in dataset_dict
            and len(dataset_dict["train"]) > 0
        ):
            logger.debug(
                "[Math_Vision] Example after mode adjustment: %s",
                dataset_dict["train"][0],
            )

        return dataset_dict

    def _resolve_max_samples(self):
        """Return ``max_samples`` from ``full_config``, falling back to ``config``.

        A missing or ``None`` ``full_config`` is treated as empty, so the lookup
        falls through to ``self.config`` instead of raising ``AttributeError``.
        Returns ``None`` when neither source defines the key.
        """
        full_config = getattr(self, "full_config", None) or {}
        max_samples = full_config.get("max_samples", None)
        if max_samples is None:
            max_samples = self.config.get("max_samples", None)
        return max_samples

    def _build_sft_datasets(self) -> DatasetDict:
        """Build the SFT splits; mode-specific handling lives in ``_build_datasets``."""
        return self._build_datasets()

    def _build_rl_datasets(self) -> DatasetDict:
        """Build the RL splits; mode-specific handling lives in ``_build_datasets``."""
        return self._build_datasets()

    @classmethod
    def _keep_keys(cls) -> List[str]:
        """Required fields in preprocessed data (also the columns kept in output)."""
        return ["prompt", "completion", "solution", "image_path"]

    def get_env_cls(self):
        """Return the environment class used with this dataset."""
        return MathVisionEnv

    def get_generation_manager_cls(self):
        """Return the interaction manager class used for generation."""
        return SingleTurnInteractionManager