import math
import os
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from datasets import load_dataset
from jinja2 import Template
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..models.transformers.qwen2_vl import get_rope_index
from . import torch_functional as VF


def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate dataset features into a batch: stack tensor fields, gather non-tensor fields into object arrays."""
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)
    for feature in features:
        for key, value in feature.items():
            if isinstance(value, torch.Tensor):
                tensors[key].append(value)
            else:
                non_tensors[key].append(value)

    for key, value in tensors.items():
        tensors[key] = torch.stack(value, dim=0)

    for key, value in non_tensors.items():
        non_tensors[key] = np.array(value, dtype=object)

    return {**tensors, **non_tensors}
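

# A small illustrative sketch (the inputs below are assumptions, not part of the module):
# collate_fn stacks tensor-valued fields and keeps everything else as numpy object arrays.
#
#     features = [
#         {"input_ids": torch.tensor([1, 2, 3]), "ground_truth": "4"},
#         {"input_ids": torch.tensor([5, 6, 7]), "ground_truth": "8"},
#     ]
#     batch = collate_fn(features)
#     batch["input_ids"].shape   # torch.Size([2, 3])
#     batch["ground_truth"]      # array(['4', '8'], dtype=object)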


def process_image(
    image: Union[Dict[str, Any], ImageObject, str, bytes], min_pixels: Optional[int], max_pixels: Optional[int]
) -> ImageObject:
    """Open an image from a path, raw bytes, or dict and rescale its area into the [min_pixels, max_pixels] range."""
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, dict):
        image = Image.open(BytesIO(image["bytes"]))
    elif isinstance(image, bytes):
        image = Image.open(BytesIO(image))

    image.load()
    if max_pixels is not None and (image.width * image.height) > max_pixels:
        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
        image = image.resize((width, height))

    if min_pixels is not None and (image.width * image.height) < min_pixels:
        resize_factor = math.sqrt(min_pixels / (image.width * image.height))
        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
        image = image.resize((width, height))

    if image.mode != "RGB":
        image = image.convert("RGB")

    return image
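
# Worked example (illustrative numbers): a 4000x3000 image has 12,000,000 pixels.
# With max_pixels = 1_000_000 the scale factor is sqrt(1_000_000 / 12_000_000), about 0.2887,
# so the image is resized to roughly 1154x866 (about 999,000 pixels), just under the cap,
# while preserving the original aspect ratio.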


class RLHFDataset(Dataset):
    """
    Dataset for RLHF training. Each example is expected to provide a prompt column and,
    optionally, an answer column and an image column, alongside any extra metadata.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        processor: Optional[ProcessorMixin],
        prompt_key: str = "prompt",
        answer_key: str = "answer",
        image_key: str = "images",
        image_dir: Optional[str] = None,
        max_prompt_length: int = 1024,
        truncation: str = "error",
        format_prompt: Optional[str] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        filter_overlong_prompts: bool = True,
    ):
        self.tokenizer = tokenizer
        self.processor = processor
        self.prompt_key = prompt_key
        self.answer_key = answer_key
        self.image_key = image_key
        self.image_dir = image_dir
        self.max_prompt_length = max_prompt_length
        self.truncation = truncation
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.filter_overlong_prompts = filter_overlong_prompts

        # "path/to/data@split" selects a specific split; otherwise default to the train split
        if "@" in data_path:
            data_path, data_split = data_path.split("@")
        else:
            data_split = "train"

        if os.path.isdir(data_path):
            # local directory: infer the dataset builder (e.g. json, parquet, csv) from the
            # extension of the first file; jsonl files use the json builder
            file_type = os.path.splitext(os.listdir(data_path)[0])[-1][1:].replace("jsonl", "json")
            self.dataset = load_dataset(file_type, data_dir=data_path, split=data_split)
        elif os.path.isfile(data_path):
            # single local file: infer the builder from its extension
            file_type = os.path.splitext(data_path)[-1][1:].replace("jsonl", "json")
            self.dataset = load_dataset(file_type, data_files=data_path, split=data_split)
        else:
            # otherwise treat data_path as a dataset name on the Hugging Face Hub
            self.dataset = load_dataset(data_path, split=data_split)

        self.format_prompt = None
        if format_prompt:
            with open(format_prompt, encoding="utf-8") as f:
                self.format_prompt = f.read()

        if self.filter_overlong_prompts:
            self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts")
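
    # Illustrative data_path values accepted by __init__ (names and paths are assumptions):
    #   "some_org/some_dataset@test"   -> Hugging Face Hub dataset, test split
    #   "/data/rlhf/train.jsonl"       -> single local jsonl file (json builder), train split
    #   "/data/rlhf/parquet_dir"       -> local directory, builder inferred from the first file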

    def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
        prompt_str: str = example[self.prompt_key]
        if self.format_prompt:
            format_prompt = Template(self.format_prompt.strip())
            prompt_str = format_prompt.render(content=prompt_str)

        if self.image_key in example:
            # split the prompt on <image> placeholders and interleave image entries with the
            # surrounding text segments, following the chat-template content format
            content_list = []
            for i, content in enumerate(prompt_str.split("<image>")):
                if i != 0:
                    content_list.append({"type": "image"})

                if content:
                    content_list.append({"type": "text", "text": content})

            return [{"role": "user", "content": content_list}]
        else:
            return [{"role": "user", "content": prompt_str}]
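
    # For instance (illustrative prompt), "<image>How many objects are shown?" becomes:
    #   [{"role": "user", "content": [{"type": "image"},
    #                                 {"type": "text", "text": "How many objects are shown?"}]}]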

    def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
        # keep only examples whose chat-templated prompt does not exceed max_prompt_length
        messages = self._build_messages(example)
        processing_class = self.processor if self.processor is not None else self.tokenizer
        return (
            len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        example: dict = self.dataset[index]
        messages = self._build_messages(example)

        if self.image_key in example:
            # multimodal example: apply the processor's chat template and encode text and images together
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            images = example.pop(self.image_key)
            if self.image_dir is not None and len(images) != 0 and isinstance(images[0], str):
                # images are stored as relative paths; resolve them against image_dir
                images = [os.path.join(self.image_dir, image) for image in images]

            resized_images = [
                process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels) for image in images
            ]
            model_inputs = self.processor(resized_images, [prompt], add_special_tokens=False, return_tensors="pt")
            input_ids = model_inputs.pop("input_ids")[0]
            attention_mask = model_inputs.pop("attention_mask")[0]
            example["multi_modal_data"] = {"images": images}
        else:
            # text-only example: tokenize the templated prompt directly
            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
            input_ids = model_inputs.pop("input_ids")[0]
            attention_mask = model_inputs.pop("attention_mask")[0]

        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
            # Qwen2-VL uses multimodal rotary embeddings (M-RoPE), so position ids depend on the image grid
            position_ids = get_rope_index(
                self.processor,
                input_ids=input_ids,
                image_grid_thw=model_inputs.get("image_grid_thw"),
                attention_mask=attention_mask,
            )
        else:
            position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)

        # pad or truncate the prompt to max_prompt_length (left padding, truncation policy from config)
        input_ids, attention_mask, position_ids = VF.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )
        raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        example["input_ids"] = input_ids
        example["attention_mask"] = attention_mask
        example["position_ids"] = position_ids
        example["raw_prompt_ids"] = raw_prompt_ids
        example["ground_truth"] = example.pop(self.answer_key)
        return example
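
    # A returned example carries, in addition to any remaining dataset columns:
    #   input_ids / attention_mask / position_ids padded (or truncated) to max_prompt_length,
    #   raw_prompt_ids (unpadded token ids of the prompt) and ground_truth (the answer column).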
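

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. The checkpoint name and
    # data path below are illustrative assumptions; substitute your own tokenizer and dataset.
    # It wires RLHFDataset and collate_fn into a plain PyTorch DataLoader.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    dataset = RLHFDataset(
        data_path="/path/to/dataset@train",
        tokenizer=tokenizer,
        processor=None,
        max_prompt_length=1024,
        truncation="right",
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["position_ids"].shape)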