model111/larm/data/builders/math_vision_builder.py
LCZZZZ's picture
Upload MemGen code and data
e34b94f verified
import os
import logging
from typing import Dict, List
from datasets import DatasetDict, load_dataset
from larm.data.builders.base_builder import BaseDatasetBuilder
from larm.data.interactions.singleturn_interaction import SingleTurnInteractionManager
from larm.common.registry import registry
from larm.data.envs.math_vision_env import MathVisionEnv
@registry.register_builder("math_vision")
class MathVisionBuilder(BaseDatasetBuilder):
    """
    Math Vision dataset builder.

    Expects preprocessed data generated by scripts/math_vision_process.py
    with fields: prompt, completion, solution, image_path

    Expected directory structure:
        data_path/
            train.json   # preprocessed samples with prompt, completion, solution, image_path
            valid.json
            test.json
    """
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/math_vision/default.yaml",
    }
    CACHE_PATH = None

    def _build_datasets(self) -> DatasetDict:
        """Build SFT/RL datasets from preprocessed JSON files.

        The preprocessing (download, split, format) should be done via:
            python scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml

        Data is always loaded from the 'data/math_vision' directory
        (unified location for both SFT and GRPO).

        Returns:
            DatasetDict with the available splits ("train"/"valid"/"test"),
            each restricted to the columns returned by `_keep_keys()`.

        Raises:
            FileNotFoundError: if no preprocessed split files exist.
            ValueError: if a split is missing required fields.
        """
        data_path = "data/math_vision"

        dataset_dict = self._load_splits(data_path)
        dataset_dict = self._apply_max_samples(dataset_dict)
        self._verify_fields(dataset_dict)

        # Adjust completion field based on mode:
        #   sft:  keep full solution text in the completion field
        #   grpo: empty the completion so generation starts from scratch
        #         (the extracted answer stays in the solution field)
        if self.mode == "grpo":
            def _clear_completion(example):
                example["completion"] = ""  # grpo generates from scratch
                return example

            for split_name in dataset_dict.keys():
                dataset_dict[split_name] = dataset_dict[split_name].map(_clear_completion)

        keep_keys = self._keep_keys()
        for split_name in dataset_dict.keys():
            dataset_dict[split_name] = dataset_dict[split_name].select_columns(keep_keys)

        return dataset_dict

    def _load_splits(self, data_path: str) -> DatasetDict:
        """Load whichever of train/valid/test JSON files exist under data_path.

        Raises:
            FileNotFoundError: if none of the split files are present.
        """
        data_files = {}
        for split_name in ("train", "valid", "test"):
            json_path = os.path.join(data_path, f"{split_name}.json")
            if os.path.exists(json_path):
                data_files[split_name] = json_path
        if not data_files:
            raise FileNotFoundError(
                f"No preprocessed data found in {data_path}. "
                f"Please run 'python scripts/math_vision_process.py --output_dir {data_path}' first."
            )
        return load_dataset("json", data_files=data_files)

    def _apply_max_samples(self, dataset_dict: DatasetDict) -> DatasetDict:
        """Truncate every split to `max_samples` when configured (> 0)."""
        # `full_config` may be absent or None on this builder; guard with `or {}`
        # so a None value does not raise, then fall back to the builder config.
        full_config = getattr(self, "full_config", None) or {}
        max_samples = full_config.get("max_samples", None)
        if max_samples is None:
            max_samples = self.config.get("max_samples", None)
        if max_samples is not None and max_samples > 0:
            for split_name in list(dataset_dict.keys()):
                if len(dataset_dict[split_name]) > max_samples:
                    dataset_dict[split_name] = dataset_dict[split_name].select(range(max_samples))
        return dataset_dict

    def _verify_fields(self, dataset_dict: DatasetDict) -> None:
        """Raise ValueError if any split lacks the required preprocessed fields."""
        required_fields = self._keep_keys()
        for split_name, ds in dataset_dict.items():
            missing_keys = set(required_fields) - set(ds.column_names)
            if missing_keys:
                raise ValueError(
                    f"Split '{split_name}' is missing required fields: {missing_keys}. "
                    f"Please regenerate preprocessed data using scripts/math_vision_process.py"
                )

    def _build_sft_datasets(self) -> DatasetDict:
        """SFT datasets use the same preprocessed files as RL."""
        return self._build_datasets()

    def _build_rl_datasets(self) -> DatasetDict:
        """RL datasets use the same preprocessed files as SFT."""
        return self._build_datasets()

    @classmethod
    def _keep_keys(cls) -> List[str]:
        """Required fields in preprocessed data."""
        return ["prompt", "completion", "solution", "image_path"]

    def get_env_cls(self):
        """Environment class used for this dataset's rollouts."""
        return MathVisionEnv

    def get_generation_manager_cls(self):
        """Single-turn interaction: one prompt, one completion."""
        return SingleTurnInteractionManager