JustinTX's picture
Add files using upload-large-folder tool
d7b3a74 verified
import itertools
import json
import logging
import os
import random
import re
import numpy as np
import ray
try:
import pyarrow.parquet as pq
except ImportError:
pq = None
from slime.utils.types import MultimodalTypes, Sample
from .timer import Timer
__all__ = ["Dataset"]
logger = logging.getLogger(__name__)
def read_file(path):
    """Lazily yield the records stored in a ``.jsonl`` or ``.parquet`` file.

    ``path`` may carry a row-slice suffix of the form ``file@[start:end]``
    (parsed by ``_parse_generalized_path``); only the selected rows are
    yielded. For ``.jsonl`` input, blank lines and lines that fail to parse
    are skipped, so slice indices count successfully parsed records.

    This is a generator function: nothing below runs, and none of the
    exceptions are raised, until the first item is requested.

    Args:
        path: File path, optionally followed by ``@[start:end]``.

    Yields:
        One parsed record per row/line.

    Raises:
        FileNotFoundError: If the file (after stripping the slice suffix)
            does not exist.
        ImportError: If a parquet file is requested but pyarrow is missing.
        ValueError: If the file extension is neither ``.jsonl`` nor
            ``.parquet``.
    """
    path, row_slice = _parse_generalized_path(path)

    if not os.path.exists(path):
        raise FileNotFoundError(f"Prompt dataset path '{path}' does not exist.")

    if path.endswith(".jsonl"):

        def jsonl_reader(p):
            with open(p, encoding="utf-8") as f:
                # 1-based numbering so the reported line matches what an editor shows.
                for line_num, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError as e:
                        # Best effort: report the bad line and keep reading.
                        logger.warning("JSON decode error at line %d: %s", line_num, e)
                        continue
                    yield record

        reader = jsonl_reader(path)

    elif path.endswith(".parquet"):
        if pq is None:
            raise ImportError("pyarrow is required for parquet support")

        def parquet_reader(p):
            pf = pq.ParquetFile(p)
            for batch in pf.iter_batches():
                yield from batch.to_pylist()

        reader = parquet_reader(path)

    else:
        raise ValueError(f"Unsupported file format: {path}. Supported formats are .jsonl and .parquet.")

    if row_slice is not None:
        logger.info("read_file path=%s applying slice row_slice=%s", path, row_slice)
        start, stop = row_slice.start, row_slice.stop
        if (start is not None and start < 0) or (stop is not None and stop < 0):
            # itertools.islice rejects negative indices, but the path syntax
            # accepts them (e.g. "@[-100:]"). Counting from the end needs the
            # total length, so materialize the records for this case only.
            reader = iter(list(reader)[row_slice])
        else:
            reader = itertools.islice(reader, start, stop, row_slice.step)

    yield from reader
def _parse_generalized_path(s: str):
if (m := re.match(r"^(?P<real_path>.*)@\[(?P<start>-?\d*):(?P<end>-?\d*)\]$", s)) is not None:
path = m.group("real_path")
start = int(x) if (x := m.group("start")) != "" else None
end = int(x) if (x := m.group("end")) != "" else None
return path, slice(start, end)
return s, None
def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_length: int | None) -> list[Sample]:
    """Drop samples whose tokenized prompt is longer than ``max_length``.

    Args:
        origin_samples: Samples to filter. Only the first sample's prompt type
            is inspected to decide whether length filtering is possible.
        tokenizer: Called on the full list of prompt strings (with
            ``add_special_tokens=False``) when no processor is given.
        processor: If truthy, each prompt is run through it together with the
            output of ``process_vision_info`` and the length of
            ``input_ids[0]`` is used instead of the tokenizer's.
        max_length: Maximum allowed number of tokens; ``None`` disables
            filtering.

    Returns:
        ``origin_samples`` itself when filtering is disabled, the input is
        empty, or the prompts are not strings; otherwise a new list with the
        samples that fit.
    """
    # Nothing to do without a limit; an empty input also has no first sample to inspect.
    if max_length is None or not origin_samples:
        return origin_samples

    if not isinstance(origin_samples[0].prompt, str):
        logger.warning(
            "Skipping max_length check for list prompt. Set apply_chat_template=True to enable length filtering."
        )
        return origin_samples

    if processor:
        # Imported lazily so the dependency is only needed on the processor path.
        from slime.utils.processing_utils import process_vision_info

        filtered_samples = []
        for sample in origin_samples:
            multimodal_inputs = process_vision_info(sample.prompt, processor)
            processor_output = processor(text=sample.prompt, **multimodal_inputs)
            input_ids = processor_output["input_ids"][0]
            if len(input_ids) <= max_length:
                filtered_samples.append(sample)
    else:
        # Tokenize every prompt in a single batched call.
        prompts = [sample.prompt for sample in origin_samples]
        input_ids_list = tokenizer(prompts, add_special_tokens=False)["input_ids"]
        filtered_samples = [
            sample
            for sample, input_ids in zip(origin_samples, input_ids_list, strict=True)
            if len(input_ids) <= max_length
        ]

    logger.info(f"Filtered {len(origin_samples) - len(filtered_samples)} samples longer than max_length={max_length}.")

    return filtered_samples
def _build_messages(data: dict, prompt_key: str, as_conversation: bool, multimodal_keys: dict = None):
prompt = data.get(prompt_key)
if isinstance(prompt, str):
# If prompt is a string and we don't apply chat template, return the prompt as is.
if not as_conversation:
return prompt
else:
prompt = [{"role": "user", "content": prompt}]
if multimodal_keys:
# Build mapping: placeholder -> (MultimodalType, content_list)
multimodals = {}
for type_name, data_key in multimodal_keys.items():
mt = MultimodalTypes.get(type_name)
if mt:
multimodal_data = data.get(data_key)
if multimodal_data is not None:
multimodals[mt.placeholder] = (mt, list(multimodal_data))
pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")"
for message in prompt:
if isinstance(message["content"], str):
content_list = []
for segment in re.split(pattern, message["content"]):
if not segment:
continue
if segment in multimodals:
mt, content = multimodals[segment]
assert len(content) > 0, (
f"Not enough {mt.name} data: more '{mt.placeholder}' placeholders in prompt "
f"than {mt.name}s provided in data"
)
content_list.append({"type": mt.name, mt.name: content.pop(0)})
else:
content_list.append({"type": "text", "text": segment})
message["content"] = content_list
elif isinstance(message["content"], list):
# TODO: handle more general cases. where message['content'] is a dict and contains multiple types of content.
# e.g.
# "content": [
# {
# "type": "image",
# "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
# },
# {"type": "text", "text": "Describe this image."},
# ],
logger.warning("message['content'] is a list of dicts, no processing will be done.")
continue
else:
raise ValueError(
f"Unsupported content type: {type(message['content'])}, expected str or list of dicts"
)
for placeholder, (mt, remaining) in multimodals.items():
assert len(remaining) == 0, (
f"Multimodal data count mismatch: {len(remaining)} more {mt.name}(s)"
f"than '{placeholder}' placeholders in prompt"
)
return prompt
class Dataset:
    """In-memory prompt dataset with deterministic per-epoch shuffling.

    Attributes set by ``__init__``:
        origin_samples: Samples in file order (after optional length filtering).
        samples: The current view of the samples; starts as ``origin_samples``
            and is replaced by a shuffled list on ``shuffle``.
        epoch_id: Epoch the current ``samples`` order belongs to (``-1`` before
            the first shuffle).
        seed: Base seed combined with the epoch id when shuffling.
    """

    def __init__(
        self,
        path,
        tokenizer,
        processor,
        max_length,
        *,
        prompt_key="text",
        multimodal_keys=None,
        label_key=None,
        tool_key=None,
        metadata_key="metadata",
        seed=42,
        apply_chat_template=False,
        apply_chat_template_kwargs=None,
    ):
        """Read every record of ``path`` and turn it into a ``Sample``.

        Args:
            path: Dataset path understood by ``read_file``.
            tokenizer: Used for ``apply_chat_template`` and length filtering.
            processor: If truthy, ``process_vision_info`` is run on each
                prompt (which must then be a list) to build the sample's
                ``multimodal_inputs``.
            max_length: Prompts longer than this are dropped via
                ``filter_long_prompt``; ``None`` keeps everything.
            prompt_key: Record key holding the prompt.
            multimodal_keys: Forwarded to ``_build_messages``; when not
                ``None`` the prompt is built in conversation format.
            label_key: Record key holding the label; ``None`` means no label.
            tool_key: Record key holding the tool list (a list, a JSON string,
                or a numpy array); also stored under ``metadata["tools"]``.
            metadata_key: Record key holding the metadata dict.
            seed: Base seed for ``shuffle``.
            apply_chat_template: Whether to render the prompt with the
                tokenizer's chat template.
            apply_chat_template_kwargs: Extra keyword arguments for
                ``tokenizer.apply_chat_template``.
        """
        # Both chat templates and multimodal inputs require conversation format (list of message dicts)
        as_conversation = apply_chat_template or (multimodal_keys is not None)
        template_kwargs = apply_chat_template_kwargs or {}

        if processor:
            # Imported lazily so the dependency is only needed when a processor is used.
            from slime.utils.processing_utils import process_vision_info

        origin_samples = []
        for data in read_file(path):
            prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)

            metadata = data.get(metadata_key) or {}
            tools = None
            if tool_key is not None and tool_key in data:
                tools = data[tool_key]
                if isinstance(tools, str):
                    tools = json.loads(tools)
                elif isinstance(tools, np.ndarray):
                    tools = tools.tolist()
                assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
                metadata["tools"] = tools

            if apply_chat_template:
                output_prompt = tokenizer.apply_chat_template(
                    prompt,
                    tools=tools,
                    tokenize=False,
                    add_generation_prompt=True,
                    **template_kwargs,
                )
            else:
                output_prompt = prompt

            if processor:
                assert isinstance(
                    prompt, list
                ), f"prompt must be a list when processor is not None, got {type(prompt)} instead"
                multimodal_inputs = process_vision_info(prompt, processor)
            else:
                multimodal_inputs = None

            origin_samples.append(
                Sample(
                    prompt=output_prompt,
                    label=data[label_key] if label_key is not None else None,
                    metadata=metadata,
                    multimodal_inputs=multimodal_inputs,
                )
            )

        if max_length is not None:
            self.origin_samples = filter_long_prompt(origin_samples, tokenizer, processor, max_length)
        else:
            self.origin_samples = origin_samples

        self.epoch_id = -1
        self.seed = seed
        self.samples = self.origin_samples

    def shuffle(self, new_epoch_id):
        """Reorder ``samples`` deterministically for ``new_epoch_id``.

        Calling it again with the epoch that is already active is a no-op.
        """
        if self.epoch_id == new_epoch_id:
            return

        # A private generator gives the same permutation as seeding the global
        # ``random`` module with this value, without resetting the RNG state
        # that the rest of the process shares.
        rng = random.Random(self.seed + new_epoch_id)
        permutation = list(range(len(self.samples)))
        rng.shuffle(permutation)
        self.samples = [self.origin_samples[i] for i in permutation]
        self.epoch_id = new_epoch_id

    def __getitem__(self, idx):
        """Return the sample at ``idx`` in the current order."""
        return self.samples[idx]

    def __len__(self):
        """Return the number of samples in the current view."""
        return len(self.samples)
def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
    """Return how many micro batches first-fit packing needs.

    Each length is added to the earliest existing batch that stays within
    ``max_tokens_per_gpu``; if none can take it, a new batch is opened (so a
    single length above the limit still gets a batch of its own).
    """
    bin_loads = []
    for seq_len in total_lengths:
        # Index of the first batch with enough room left, or None if there is none.
        slot = next(
            (idx for idx, load in enumerate(bin_loads) if load + seq_len <= max_tokens_per_gpu),
            None,
        )
        if slot is None:
            bin_loads.append(seq_len)
        else:
            bin_loads[slot] += seq_len
    return len(bin_loads)
def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
    """Fetch this data-parallel rank's rollout data and narrow its lengths.

    The entry of ``rollout_data_ref`` at ``dp_rank`` is resolved through
    ``ray.get`` on its ``inner`` attribute. The ``"partition"`` entry is
    removed from the fetched dict and used as the list of indices that select
    this rank's items from ``"total_lengths"``. ``args`` is not used here.

    Returns:
        The fetched dict, without ``"partition"`` and with
        ``"total_lengths"`` reduced to the partition's entries.
    """
    assert len(rollout_data_ref) == dp_size
    shard = ray.get(rollout_data_ref[dp_rank].inner)

    indices = shard.pop("partition")
    all_lengths = shard["total_lengths"]

    # save the seqlen of the whole rollout batch
    # NOTE(review): this only has a lasting effect if Timer() hands back a
    # shared instance — confirm against the Timer implementation.
    Timer().seq_lens = all_lengths

    shard["total_lengths"] = [all_lengths[idx] for idx in indices]
    return shard