| import itertools |
| import json |
| import logging |
| import os |
| import random |
| import re |
|
|
| import numpy as np |
| import ray |
|
|
| try: |
| import pyarrow.parquet as pq |
| except ImportError: |
| pq = None |
|
|
| from slime.utils.types import MultimodalTypes, Sample |
|
|
| from .timer import Timer |
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["Dataset"]


# Module logger used by read_file, filter_long_prompt and _build_messages.
logger = logging.getLogger(__name__)
|
|
|
|
def read_file(path):
    """Yield one dict per record of a ``.jsonl`` or ``.parquet`` prompt file.

    ``path`` may end in a row-slice suffix ``@[start:end]`` (parsed by
    ``_parse_generalized_path``), e.g. ``"data.jsonl@[100:200]"``; only the
    selected rows are then yielded. Either bound may be omitted or negative,
    with the same meaning as Python list slicing.

    This is a generator, so the checks below run on the first ``next()`` call,
    not when ``read_file`` itself is called.

    Raises:
        FileNotFoundError: if the file part of ``path`` does not exist.
        ImportError: for ``.parquet`` input when pyarrow is not installed.
        ValueError: for any extension other than ``.jsonl`` / ``.parquet``.
    """
    path, row_slice = _parse_generalized_path(path)

    if not os.path.exists(path):
        raise FileNotFoundError(f"Prompt dataset path '{path}' does not exist.")

    if path.endswith(".jsonl"):

        def jsonl_reader(p):
            with open(p, encoding="utf-8") as f:
                # 1-based so the reported number matches what editors show.
                for line_num, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError as e:
                        # Best-effort: skip malformed lines, but report them.
                        logger.warning("JSON decode error at line %d of %s: %s", line_num, p, e)
                        continue
                    yield record

        reader = jsonl_reader(path)

    elif path.endswith(".parquet"):
        if pq is None:
            raise ImportError("pyarrow is required for parquet support")

        def parquet_reader(p):
            pf = pq.ParquetFile(p)
            # Convert one record batch at a time instead of the whole table.
            for batch in pf.iter_batches():
                yield from batch.to_pylist()

        reader = parquet_reader(path)

    else:
        raise ValueError(f"Unsupported file format: {path}. Supported formats are .jsonl and .parquet.")

    if row_slice is not None:
        logger.info("read_file path=%s applying slice row_slice=%s", path, row_slice)
        bounds = (row_slice.start, row_slice.stop, row_slice.step)
        if any(bound is not None and bound < 0 for bound in bounds):
            # itertools.islice rejects negative indices, but the path syntax
            # allows them (e.g. "@[-100:]"). Resolving them needs the total row
            # count, so materialize the rows and apply ordinary list slicing.
            reader = iter(list(reader)[row_slice])
        else:
            reader = itertools.islice(reader, row_slice.start, row_slice.stop, row_slice.step)

    yield from reader
|
|
|
|
def _parse_generalized_path(s: str):
    """Split an optional ``@[start:end]`` row-slice suffix off a dataset path.

    ``"data.jsonl@[10:20]"`` becomes ``("data.jsonl", slice(10, 20))``. Either
    bound may be left empty (giving ``None``) or written with a leading ``-``.
    A string without such a suffix is returned unchanged, paired with ``None``.
    """
    match = re.match(r"^(?P<real_path>.*)@\[(?P<start>-?\d*):(?P<end>-?\d*)\]$", s)
    if match is None:
        return s, None

    def _bound(group_name):
        # An empty capture means the bound was omitted.
        text = match.group(group_name)
        return None if text == "" else int(text)

    return match.group("real_path"), slice(_bound("start"), _bound("end"))
|
|
|
|
def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_length: int | None) -> list[Sample]:
    """Drop samples whose tokenized prompt has more than ``max_length`` tokens.

    Args:
        origin_samples: samples to check; the list itself is not modified.
        tokenizer: used when ``processor`` is falsy; called once on the whole
            batch of prompt strings with ``add_special_tokens=False``.
        processor: when truthy, each prompt is run through
            ``process_vision_info`` and ``processor`` individually, and the
            length of the first row of ``input_ids`` is checked.
        max_length: inclusive token limit; ``None`` disables filtering.

    Returns:
        ``origin_samples`` itself when nothing is checked (``max_length`` is
        ``None``, the list is empty, or the first sample's prompt is not a
        string); otherwise a new list of the samples that fit, in order.
    """
    if max_length is None or not origin_samples:
        # Nothing to check; the emptiness test also guards origin_samples[0].
        return origin_samples

    if not isinstance(origin_samples[0].prompt, str):
        logger.warning(
            "Skipping max_length check for list prompt. Set apply_chat_template=True to enable length filtering."
        )
        return origin_samples

    if processor:
        from slime.utils.processing_utils import process_vision_info

        filtered_samples = []
        for sample in origin_samples:
            # NOTE(review): vision inputs are re-derived from the rendered prompt
            # string here, while Dataset builds sample.multimodal_inputs from the
            # message list; confirm process_vision_info finds the images when it
            # is given a string, otherwise image tokens are not counted.
            multimodal_inputs = process_vision_info(sample.prompt, processor)
            processor_output = processor(text=sample.prompt, **multimodal_inputs)
            input_ids = processor_output["input_ids"][0]
            if len(input_ids) <= max_length:
                filtered_samples.append(sample)
    else:
        prompts = [sample.prompt for sample in origin_samples]
        input_ids_list = tokenizer(prompts, add_special_tokens=False)["input_ids"]
        filtered_samples = [
            sample
            for sample, input_ids in zip(origin_samples, input_ids_list, strict=True)
            if len(input_ids) <= max_length
        ]

    logger.info(
        "Filtered %d samples longer than max_length=%s.",
        len(origin_samples) - len(filtered_samples),
        max_length,
    )
    return filtered_samples
|
|
|
|
| def _build_messages(data: dict, prompt_key: str, as_conversation: bool, multimodal_keys: dict = None): |
| prompt = data.get(prompt_key) |
|
|
| if isinstance(prompt, str): |
| |
| if not as_conversation: |
| return prompt |
| else: |
| prompt = [{"role": "user", "content": prompt}] |
|
|
| if multimodal_keys: |
| |
| multimodals = {} |
| for type_name, data_key in multimodal_keys.items(): |
| mt = MultimodalTypes.get(type_name) |
| if mt: |
| multimodal_data = data.get(data_key) |
| if multimodal_data is not None: |
| multimodals[mt.placeholder] = (mt, list(multimodal_data)) |
|
|
| pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")" |
|
|
| for message in prompt: |
| if isinstance(message["content"], str): |
| content_list = [] |
| for segment in re.split(pattern, message["content"]): |
| if not segment: |
| continue |
| if segment in multimodals: |
| mt, content = multimodals[segment] |
| assert len(content) > 0, ( |
| f"Not enough {mt.name} data: more '{mt.placeholder}' placeholders in prompt " |
| f"than {mt.name}s provided in data" |
| ) |
| content_list.append({"type": mt.name, mt.name: content.pop(0)}) |
| else: |
| content_list.append({"type": "text", "text": segment}) |
| message["content"] = content_list |
|
|
| elif isinstance(message["content"], list): |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| logger.warning("message['content'] is a list of dicts, no processing will be done.") |
| continue |
| else: |
| raise ValueError( |
| f"Unsupported content type: {type(message['content'])}, expected str or list of dicts" |
| ) |
|
|
| for placeholder, (mt, remaining) in multimodals.items(): |
| assert len(remaining) == 0, ( |
| f"Multimodal data count mismatch: {len(remaining)} more {mt.name}(s)" |
| f"than '{placeholder}' placeholders in prompt" |
| ) |
|
|
| return prompt |
|
|
|
|
class Dataset:
    """Prompt dataset read eagerly from a ``.jsonl`` / ``.parquet`` file.

    Every record yielded by ``read_file(path)`` becomes one ``Sample``. The
    construction order is kept in ``origin_samples``; ``samples`` holds the
    order for the current epoch and is re-permuted by :meth:`shuffle`.

    Args:
        path: dataset path, optionally with an ``@[start:end]`` row slice.
        tokenizer: used for ``apply_chat_template`` and length filtering.
        processor: optional multimodal processor; when truthy, each sample's
            ``multimodal_inputs`` come from ``process_vision_info``.
        max_length: forwarded to ``filter_long_prompt``; ``None`` keeps all rows.
        prompt_key: record key holding the prompt.
        multimodal_keys: optional ``{type name: record key}`` mapping,
            forwarded to ``_build_messages``.
        label_key: record key for ``Sample.label``; ``None`` sets labels to None.
        tool_key: record key holding tool definitions (a list, a JSON string,
            or a numpy array).
        metadata_key: record key holding per-sample metadata.
        seed: base seed for per-epoch shuffling.
        apply_chat_template: render prompts with ``tokenizer.apply_chat_template``.
        apply_chat_template_kwargs: extra keyword arguments for that call.
    """

    def __init__(
        self,
        path,
        tokenizer,
        processor,
        max_length,
        *,
        prompt_key="text",
        multimodal_keys=None,
        label_key=None,
        tool_key=None,
        metadata_key="metadata",
        seed=42,
        apply_chat_template=False,
        apply_chat_template_kwargs=None,
    ):
        # Any multimodal config forces the message-list form, because the
        # placeholders are expanded inside message contents.
        wrap_as_conversation = apply_chat_template or (multimodal_keys is not None)
        template_kwargs = apply_chat_template_kwargs or {}

        collected = []
        for record in read_file(path):
            messages = _build_messages(record, prompt_key, wrap_as_conversation, multimodal_keys)

            metadata = record.get(metadata_key) or {}
            tools = None
            if tool_key is not None and tool_key in record:
                tools = self._normalize_tools(record[tool_key])
                metadata["tools"] = tools

            if apply_chat_template:
                rendered = tokenizer.apply_chat_template(
                    messages,
                    tools=tools,
                    tokenize=False,
                    add_generation_prompt=True,
                    **template_kwargs,
                )
            else:
                rendered = messages

            multimodal_inputs = None
            if processor:
                from slime.utils.processing_utils import process_vision_info

                assert isinstance(
                    messages, list
                ), f"prompt must be a list when processor is not None, got {type(messages)} instead"
                multimodal_inputs = process_vision_info(messages, processor)

            label = None if label_key is None else record[label_key]
            collected.append(
                Sample(
                    prompt=rendered,
                    label=label,
                    metadata=metadata,
                    multimodal_inputs=multimodal_inputs,
                )
            )

        if max_length is None:
            self.origin_samples = collected
        else:
            self.origin_samples = filter_long_prompt(collected, tokenizer, processor, max_length)

        self.epoch_id = -1
        self.seed = seed
        self.samples = self.origin_samples

    @staticmethod
    def _normalize_tools(raw_tools):
        """Return a record's tool definitions as a list (JSON strings and numpy arrays are converted)."""
        if isinstance(raw_tools, str):
            tools = json.loads(raw_tools)
        elif isinstance(raw_tools, np.ndarray):
            tools = raw_tools.tolist()
        else:
            tools = raw_tools
        assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
        return tools

    def shuffle(self, new_epoch_id):
        """Reorder ``samples`` deterministically for ``new_epoch_id``.

        Calling again with the current epoch id is a no-op. Note that this
        reseeds the module-level ``random`` generator with
        ``seed + new_epoch_id``, a process-wide side effect.
        """
        if self.epoch_id != new_epoch_id:
            random.seed(self.seed + new_epoch_id)
            order = list(range(len(self.samples)))
            random.shuffle(order)
            self.samples = [self.origin_samples[idx] for idx in order]
            self.epoch_id = new_epoch_id

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
|
|
|
|
def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
    """Count micro-batches needed by a greedy first-fit packing of ``total_lengths``.

    Lengths are taken in the given order. Each one goes into the first existing
    micro-batch whose running total stays ``<= max_tokens_per_gpu``; if none
    fits, it opens a new micro-batch (so an over-budget length gets its own).
    The result is the first-fit count, not necessarily the optimal minimum.
    """
    bin_loads = []
    for length in total_lengths:
        target = next(
            (idx for idx, load in enumerate(bin_loads) if load + length <= max_tokens_per_gpu),
            None,
        )
        if target is None:
            bin_loads.append(length)
        else:
            bin_loads[target] += length
    return len(bin_loads)
|
|
|
|
def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
    """Fetch this data-parallel rank's rollout data and narrow its ``total_lengths``.

    ``args`` is accepted but not used here. ``rollout_data_ref`` must have one
    entry per DP rank; entry ``dp_rank`` is expected to expose a Ray object ref
    as ``.inner`` (presumably a small wrapper — confirm at the caller).

    The fetched dict's ``"partition"`` entry (indices into ``total_lengths``) is
    removed. ``Timer().seq_lens`` receives the full, unpartitioned length list,
    while the returned dict's ``"total_lengths"`` keeps only the entries
    selected by ``partition``, in partition order.
    """
    assert len(rollout_data_ref) == dp_size
    local_ref = rollout_data_ref[dp_rank]
    rollout_data = ray.get(local_ref.inner)

    partition = rollout_data.pop("partition")
    all_lengths = rollout_data["total_lengths"]

    # The timer gets every sequence length, not just this rank's share.
    Timer().seq_lens = all_lengths
    rollout_data["total_lengths"] = [all_lengths[idx] for idx in partition]
    return rollout_data
|
|