import torch from PIL import Image from torch.utils.data import Dataset from data.processors import get_image_string import logging class BaseDataset(Dataset): def __init__(self, dataset, tokenizer, image_processor, mp_image_token_length, relevance_min_rating=1, image_correspondence_min_rating=1, visual_dependency_min_rating=1, formatting_min_rating=1): self.dataset = dataset self.tokenizer = tokenizer self.image_processor = image_processor self.mp_image_token_length = mp_image_token_length self.relevance_min_rating = relevance_min_rating self.image_correspondence_min_rating = image_correspondence_min_rating self.visual_dependency_min_rating = visual_dependency_min_rating self.formatting_min_rating = formatting_min_rating self.prefix_len = self._get_prefix_len() def __len__(self): return len(self.dataset) def _get_prefix_len(self): random_string_5_letters = "xzyvd" random_string_chat_templated = self.tokenizer.apply_chat_template([{"role": "assistant", "content": random_string_5_letters}], tokenize=False, add_special_tokens=False) random_string_location = random_string_chat_templated.find(random_string_5_letters) return len(self.tokenizer.encode(random_string_chat_templated[:random_string_location])) def _get_messages(self, item, splitted_image_counts): messages = [] for index, text in enumerate(item['texts']): try: if item.get('relevance_ratings') is not None and item['relevance_ratings'][index] is not None and item['relevance_ratings'][index] < self.relevance_min_rating: continue if item.get('image_correspondence_ratings') is not None and item['image_correspondence_ratings'][index] is not None and item['image_correspondence_ratings'][index] < self.image_correspondence_min_rating: continue if item.get('visual_dependency_ratings') is not None and item['visual_dependency_ratings'][index] is not None and item['visual_dependency_ratings'][index] < self.visual_dependency_min_rating: continue if item.get('formatting_ratings') is not None and item['formatting_ratings'][index] is not None and item['formatting_ratings'][index] < self.formatting_min_rating: continue except Exception as e: logging.warning(f"Error processing item: {item}, index: {index}: {e}") messages.append({"role": "user", "content": text['user']}) messages.append({"role": "assistant", "content": text['assistant']}) if len(messages) == 0: return messages # Safety check to ensure no image tokens are present in the text before adding them. for msg in messages: if self.tokenizer.image_token in msg["content"]: logging.warning(f"Found and removed an image token in the {msg['role']} text before adding the image string.") msg["content"] = msg["content"].replace(self.tokenizer.image_token, "") if len(splitted_image_counts) > 0: image_string = get_image_string(self.tokenizer, splitted_image_counts, self.mp_image_token_length) messages[0]["content"] = image_string + messages[0]["content"] return messages def _process_images(self, images): processed_images = [] splitted_image_counts = [] for image in images: if isinstance(image, Image.Image): if image.mode != 'RGB': image = image.convert('RGB') processed_image, splitted_image_count = self.image_processor(image) if not hasattr(self.tokenizer, "global_image_token") and splitted_image_count[0]*splitted_image_count[1] == len(processed_image) - 1: # If the tokenizer doesn't have a global image token, but the processor generated it, remove it processed_image = processed_image[1:] processed_images.append(processed_image) splitted_image_counts.append(splitted_image_count) else: raise ValueError(f"Error processing image: {image}") return processed_images, splitted_image_counts def _prepare_inputs_and_loss_mask(self, messages): conv_ids = self.tokenizer.apply_chat_template( messages, tokenize=True, add_special_tokens=False, return_dict=True, ) mask = [0] * len(conv_ids["input_ids"]) # Locate each assistant turn and flip its mask to 1 cursor = 0 for msg in messages: segment_ids = self.tokenizer.apply_chat_template( [msg], tokenize=True, add_special_tokens=False ) seg_len = len(segment_ids) if msg["role"] == "assistant": start = cursor + self.prefix_len end = cursor + seg_len mask[start:end] = [1] * (end - start) # attend to these tokens cursor += seg_len return torch.tensor(conv_ids["input_ids"]), torch.tensor(mask).to(torch.bool), torch.tensor(conv_ids["attention_mask"]) class VQADataset(BaseDataset): # Visual Question Answering Dataset def iter_for_worker(self): # with iterable datasets, each worker gets different shards for data in self.dataset: yield self._process_data(data) def __getitem__(self, idx): item = self.dataset[idx] return self._process_data(item) def _process_data(self, item): # Handle images (should be a list) if item['images'] is None: images_data = [] else: images_data = item['images'] if not isinstance(images_data, list): images_data = [images_data] processed_images = [] splitted_image_counts = [] if images_data: # Only process if there are images processed_images, splitted_image_counts = self._process_images(images_data) messages = self._get_messages(item, splitted_image_counts) if len(messages) == 0: return None input_ids, mask, attention_mask = self._prepare_inputs_and_loss_mask(messages) labels = self._get_labels(input_ids, mask) return { "images": processed_images, "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } def _get_labels(self, input_ids, mask): labels = input_ids.clone().masked_fill(~mask, -100) labels = labels.roll(-1) # Shift labels for causal LM labels[-1] = -100 # Last token has no target return labels