|
|
import os |
|
|
import logging |
|
|
from typing import Dict, List |
|
|
|
|
|
from datasets import DatasetDict, load_dataset |
|
|
|
|
|
from larm.data.builders.base_builder import BaseDatasetBuilder |
|
|
from larm.data.interactions.singleturn_interaction import SingleTurnInteractionManager |
|
|
from larm.common.registry import registry |
|
|
from larm.data.envs.math_vision_env import MathVisionEnv |
|
|
|
|
|
|
|
|
@registry.register_builder("math_vision")
class MathVisionBuilder(BaseDatasetBuilder):
    """
    Math Vision dataset builder.

    Expects preprocessed data generated by scripts/math_vision_process.py
    with fields: prompt, completion, solution, image_path

    Expected directory structure:
    data_path/
        train.json   # preprocessed samples with prompt, completion, solution, image_path
        valid.json
        test.json
    """

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/math_vision/default.yaml",
    }
    CACHE_PATH = None

    def _build_datasets(self) -> DatasetDict:
        """Build SFT/RL datasets from preprocessed JSON files.

        The preprocessing (download, split, format) should be done via:
            python scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml

        Data is always loaded from 'data/math_vision' directory.

        Returns:
            DatasetDict containing whichever of the train/valid/test splits
            exist on disk, restricted to the columns from ``_keep_keys()``.

        Raises:
            FileNotFoundError: if no preprocessed split file is found.
            ValueError: if any split is missing a required field.
        """
        data_path = "data/math_vision"

        data_files = self._collect_data_files(data_path)
        dataset_dict = load_dataset("json", data_files=data_files)

        # Optionally cap each split at `max_samples` rows (debug / smoke runs).
        max_samples = self._resolve_max_samples()
        if max_samples is not None and max_samples > 0:
            for split_name in list(dataset_dict.keys()):
                if len(dataset_dict[split_name]) > max_samples:
                    dataset_dict[split_name] = dataset_dict[split_name].select(
                        range(max_samples)
                    )

        self._validate_required_fields(dataset_dict)

        if self.mode == "grpo":
            # GRPO rollouts generate their own completions, so blank out the
            # SFT completion targets; rewards use `solution` instead.
            # NOTE(review): defined once here rather than per split iteration.
            def use_solution(example):
                example["completion"] = ""
                return example

            for split_name in dataset_dict.keys():
                dataset_dict[split_name] = dataset_dict[split_name].map(use_solution)

        # Drop any extra columns so downstream code sees a fixed schema.
        keep_keys = self._keep_keys()
        for split_name in dataset_dict.keys():
            dataset_dict[split_name] = dataset_dict[split_name].select_columns(keep_keys)

        return dataset_dict

    @staticmethod
    def _collect_data_files(data_path: str) -> Dict[str, str]:
        """Map split name -> JSON path for every split file present on disk.

        Raises:
            FileNotFoundError: if none of train/valid/test JSON files exist.
        """
        data_files = {}
        for split_name in ["train", "valid", "test"]:
            json_path = os.path.join(data_path, f"{split_name}.json")
            if os.path.exists(json_path):
                data_files[split_name] = json_path

        if len(data_files) == 0:
            raise FileNotFoundError(
                f"No preprocessed data found in {data_path}. "
                f"Please run 'python scripts/math_vision_process.py --output_dir {data_path}' first."
            )
        return data_files

    def _resolve_max_samples(self):
        """Read `max_samples` from full_config, falling back to self.config.

        Returns the configured cap, or None when unset.
        """
        # `full_config` may be absent OR present-but-None depending on how the
        # builder was constructed; `or {}` guards both cases (the original
        # getattr default only covered the absent case).
        full_config = getattr(self, "full_config", None) or {}
        max_samples = full_config.get("max_samples", None)
        if max_samples is None:
            max_samples = self.config.get("max_samples", None)
        return max_samples

    def _validate_required_fields(self, dataset_dict: DatasetDict) -> None:
        """Ensure every split carries the fields downstream code relies on.

        Raises:
            ValueError: naming the split and the missing fields.
        """
        required_fields = ["prompt", "completion", "solution", "image_path"]
        for split_name, ds in dataset_dict.items():
            missing_keys = set(required_fields) - set(ds.column_names)
            if missing_keys:
                raise ValueError(
                    f"Split '{split_name}' is missing required fields: {missing_keys}. "
                    f"Please regenerate preprocessed data using scripts/math_vision_process.py"
                )

    def _build_sft_datasets(self) -> DatasetDict:
        """SFT datasets use the same preprocessed JSON format."""
        return self._build_datasets()

    def _build_rl_datasets(self) -> DatasetDict:
        """RL datasets use the same preprocessed JSON format."""
        return self._build_datasets()

    @classmethod
    def _keep_keys(cls) -> List[str]:
        """Required fields in preprocessed data."""
        return ["prompt", "completion", "solution", "image_path"]

    def get_env_cls(self):
        """Environment class used for rollouts on this task."""
        return MathVisionEnv

    def get_generation_manager_cls(self):
        """Single-turn interaction manager; one prompt, one response."""
        return SingleTurnInteractionManager
|
|
|
|
|
|