File size: 5,462 Bytes
e34b94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import logging
from typing import Dict, List
from datasets import DatasetDict, load_dataset
from larm.data.builders.base_builder import BaseDatasetBuilder
from larm.data.interactions.singleturn_interaction import SingleTurnInteractionManager
from larm.common.registry import registry
from larm.data.envs.math_vision_env import MathVisionEnv
@registry.register_builder("math_vision")
class MathVisionBuilder(BaseDatasetBuilder):
    """
    Math Vision dataset builder.
    Expects preprocessed data generated by scripts/math_vision_process.py
    with fields: prompt, completion, solution, image_path
    Expected directory structure:
        data_path/
            train.json  # preprocessed samples with prompt, completion, solution, image_path
            valid.json
            test.json
    """
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/math_vision/default.yaml",
    }
    # No remote/HF cache: data is always read from local preprocessed JSON.
    CACHE_PATH = None

    def _build_datasets(self) -> DatasetDict:
        """Build SFT/RL datasets from preprocessed JSON files.

        The preprocessing (download, split, format) should be done via:
            python scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml
        Data is always loaded from 'data/math_vision' directory.

        Returns:
            DatasetDict keyed by whichever of train/valid/test exist on disk,
            each split projected down to the columns from `_keep_keys()`.

        Raises:
            FileNotFoundError: if no preprocessed split files are present.
            ValueError: if a split is missing any required field.
        """
        # Always use data/math_vision as the data path (unified location for both SFT and GRPO)
        data_path = "data/math_vision"
        logging.info("[Math_Vision] Using unified data path: %s", data_path)

        # Collect only the split files that actually exist on disk.
        data_files = {}
        for split_name in ["train", "valid", "test"]:
            json_path = os.path.join(data_path, f"{split_name}.json")
            if os.path.exists(json_path):
                data_files[split_name] = json_path
        if not data_files:
            raise FileNotFoundError(
                f"No preprocessed data found in {data_path}. "
                f"Please run 'python scripts/math_vision_process.py --output_dir {data_path}' first."
            )
        logging.info("[Math_Vision] Found splits: %s", list(data_files.keys()))
        dataset_dict = load_dataset("json", data_files=data_files)

        # Apply max_samples if specified; full_config takes precedence over config.
        max_samples = getattr(self, 'full_config', {}).get("max_samples", None)
        if max_samples is None:
            max_samples = self.config.get("max_samples", None)
        if max_samples is not None and max_samples > 0:
            logging.info("[Math_Vision] Applying max_samples=%s", max_samples)
            for split_name in list(dataset_dict.keys()):
                original_size = len(dataset_dict[split_name])
                if original_size > max_samples:
                    dataset_dict[split_name] = dataset_dict[split_name].select(range(max_samples))
                    logging.info(
                        "[Math_Vision] %s: %d -> %d samples",
                        split_name, original_size, len(dataset_dict[split_name]),
                    )

        # Verify expected fields; reuse _keep_keys() so this check can never
        # drift out of sync with the column projection below.
        required_fields = self._keep_keys()
        for split_name, ds in dataset_dict.items():
            missing_keys = set(required_fields) - set(ds.column_names)
            if missing_keys:
                raise ValueError(
                    f"Split '{split_name}' is missing required fields: {missing_keys}. "
                    f"Please regenerate preprocessed data using scripts/math_vision_process.py"
                )
            logging.info("[Math_Vision] %s: %d samples", split_name, len(ds))

        # Adjust completion field based on mode:
        #   sft:  keep full solution text in `completion`
        #   grpo: blank out `completion` (model generates from scratch;
        #         the answer stays in `solution` for reward computation)
        if self.mode == "grpo":
            logging.info("[Math_Vision] Mode: grpo - using solution field (answer only)")

            # Loop-invariant: define the map function once, not per split.
            def _blank_completion(example):
                example["completion"] = ""  # grpo generates from scratch
                return example

            for split_name in dataset_dict.keys():
                dataset_dict[split_name] = dataset_dict[split_name].map(_blank_completion)
        else:
            logging.info("[Math_Vision] Mode: sft - using completion field (full solution text)")

        keep_keys = self._keep_keys()
        for split_name in dataset_dict.keys():
            dataset_dict[split_name] = dataset_dict[split_name].select_columns(keep_keys)

        # Log one example so misformatted data is visible early.
        if "train" in dataset_dict and len(dataset_dict["train"]) > 0:
            logging.info(
                "[Math_Vision] Example after mode adjustment: %s", dataset_dict["train"][0]
            )
        return dataset_dict

    def _build_sft_datasets(self) -> DatasetDict:
        """SFT datasets come from the same unified preprocessed files."""
        return self._build_datasets()

    def _build_rl_datasets(self) -> DatasetDict:
        """RL (GRPO) datasets come from the same unified preprocessed files."""
        return self._build_datasets()

    @classmethod
    def _keep_keys(cls) -> List[str]:
        """Required fields in preprocessed data."""
        return ["prompt", "completion", "solution", "image_path"]

    def get_env_cls(self):
        """Environment class used for reward/interaction during RL."""
        return MathVisionEnv

    def get_generation_manager_cls(self):
        """Single-turn generation: one prompt, one completion per sample."""
        return SingleTurnInteractionManager
|