# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from PIL.Image import Image

from verl.utils.dataset import RLHFDataset
from verl.utils.tokenizer import get_processor, get_tokenizer


def test_image_dataset():
    """Check the first sample RLHFDataset produces for an image dataset.

    Builds the dataset from ``hiyouga/geometry3k@test`` with the
    Qwen2.5-VL-7B-Instruct tokenizer and processor, a prompt length of 16 and
    right truncation, then verifies the keys, text fields, token ids, masks,
    position ids and image payload of ``dataset[0]``.

    NOTE(review): presumably needs access to the model and dataset named
    below (e.g. from the Hugging Face Hub) -- confirm for offline CI.
    """
    model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
    max_prompt_length = 16

    tokenizer = get_tokenizer(model_path, use_fast=True)
    processor = get_processor(model_path, use_fast=True)
    dataset = RLHFDataset(
        data_path="hiyouga/geometry3k@test",
        tokenizer=tokenizer,
        processor=processor,
        prompt_key="problem",
        answer_key="answer",
        image_key="images",
        max_prompt_length=max_prompt_length,
        truncation="right",
        filter_overlong_prompts=False,
    )

    # Expected ids of the prompt after it is cut down to max_prompt_length.
    token_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655]

    # Fetch the sample once instead of re-running __getitem__ per assertion.
    sample = dataset[0]

    assert set(sample.keys()) == {
        "problem",
        "ground_truth",
        "input_ids",
        "attention_mask",
        "position_ids",
        "raw_prompt_ids",
        "multi_modal_data",
    }
    assert sample["problem"] == (
        "Chords $\\overline{A C}$ and $\\overline{D F}$ are equidistant from the center. "
        "If the radius of $\\odot G$ is 26 find $A C$"
    )
    assert sample["ground_truth"] == "48"
    assert torch.all(sample["input_ids"] == torch.tensor(token_ids))
    assert torch.all(sample["attention_mask"] == torch.ones(max_prompt_length))

    # Every one of the three position-id rows is expected to be 0..15.
    expected_position_ids = torch.arange(max_prompt_length).unsqueeze(0).expand(3, -1)
    assert torch.all(sample["position_ids"] == expected_position_ids)
    # `==` broadcasts, so pin the shape explicitly to avoid a false positive.
    assert list(sample["position_ids"].size()) == [3, max_prompt_length]

    assert sample["raw_prompt_ids"] == token_ids
    assert isinstance(sample["multi_modal_data"]["images"][0], Image)


if __name__ == "__main__":
    test_image_dataset()