import os
import logging
from typing import Dict, List

from datasets import DatasetDict, load_dataset

from larm.data.builders.base_builder import BaseDatasetBuilder
from larm.data.interactions.singleturn_interaction import SingleTurnInteractionManager
from larm.common.registry import registry
from larm.data.envs.math_vision_env import MathVisionEnv


@registry.register_builder("math_vision")
class MathVisionBuilder(BaseDatasetBuilder):
    """
    Math Vision dataset builder.

    Expects preprocessed data generated by scripts/math_vision_process.py with fields:
        prompt, completion, solution, image_path

    Expected directory structure:
        data_path/
            train.json   # preprocessed samples with prompt, completion, solution, image_path
            valid.json
            test.json
    """

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/math_vision/default.yaml",
    }
    CACHE_PATH = None

    def _build_datasets(self) -> DatasetDict:
        """Build SFT/RL datasets from preprocessed JSON files.

        The preprocessing (download, split, format) should be done via:
            python scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml

        Data is always loaded from the 'data/math_vision' directory.
        """
        # Always use data/math_vision as the data path (unified location for both SFT and GRPO)
        data_path = "data/math_vision"
        # logging.info(f"[Math_Vision] Using unified data path: {data_path}")

        # Load preprocessed splits
        data_files = {}
        for split_name in ["train", "valid", "test"]:
            json_path = os.path.join(data_path, f"{split_name}.json")
            if os.path.exists(json_path):
                data_files[split_name] = json_path

        if len(data_files) == 0:
            raise FileNotFoundError(
                f"No preprocessed data found in {data_path}. "
                f"Please run 'python scripts/math_vision_process.py --output_dir {data_path}' first."
            )

        # logging.info(f"[Math_Vision] Loading preprocessed data from {data_path}")
        # logging.info(f"[Math_Vision] Found splits: {list(data_files.keys())}")

        dataset_dict = load_dataset("json", data_files=data_files)

        # Apply max_samples if specified
        max_samples = getattr(self, 'full_config', {}).get("max_samples", None)
        if max_samples is None:
            max_samples = self.config.get("max_samples", None)

        if max_samples is not None and max_samples > 0:
            # logging.info(f"[Math_Vision] Applying max_samples={max_samples}")
            for split_name in list(dataset_dict.keys()):
                original_size = len(dataset_dict[split_name])
                if original_size > max_samples:
                    dataset_dict[split_name] = dataset_dict[split_name].select(range(max_samples))
                    # logging.info(f"[Math_Vision] {split_name}: {original_size} -> {len(dataset_dict[split_name])} samples")
                # else:
                #     logging.info(f"[Math_Vision] {split_name}: keeping all {original_size} samples")

        # Verify expected fields
        required_fields = ["prompt", "completion", "solution", "image_path"]
        for split_name, ds in dataset_dict.items():
            missing_keys = set(required_fields) - set(ds.column_names)
            if missing_keys:
                raise ValueError(
                    f"Split '{split_name}' is missing required fields: {missing_keys}. "
                    f"Please regenerate preprocessed data using scripts/math_vision_process.py"
                )
            # logging.info(f"[Math_Vision] {split_name}: {len(ds)} samples")

        # Adjust completion field based on mode
        #   sft:  use full solution text (completion field)
        #   grpo: use extracted answer (solution field)
        if self.mode == "grpo":
            # logging.info("[Math_Vision] Mode: grpo - using solution field (answer only)")
            for split_name in dataset_dict.keys():
                # For grpo, replace completion with solution (empty completion, solution has answer)
                def use_solution(example):
                    example["completion"] = ""  # grpo generates from scratch
                    return example

                dataset_dict[split_name] = dataset_dict[split_name].map(use_solution)
        # else:
        #     logging.info("[Math_Vision] Mode: sft - using completion field (full solution text)")

        keep_keys = self._keep_keys()
        for split_name in dataset_dict.keys():
            dataset_dict[split_name] = dataset_dict[split_name].select_columns(keep_keys)

        # Log an example
        if "train" in dataset_dict and len(dataset_dict["train"]) > 0:
            example = dataset_dict["train"][0]
            # logging.info(f"[Math_Vision] Example after mode adjustment: {example}")

        return dataset_dict

    def _build_sft_datasets(self) -> DatasetDict:
        return self._build_datasets()

    def _build_rl_datasets(self) -> DatasetDict:
        return self._build_datasets()

    @classmethod
    def _keep_keys(cls) -> List[str]:
        """Required fields in preprocessed data."""
        return ["prompt", "completion", "solution", "image_path"]

    def get_env_cls(self):
        return MathVisionEnv

    def get_generation_manager_cls(self):
        return SingleTurnInteractionManager