# Source: vlm-demo / src/data/datasets.py
# (Hugging Face upload "Upload 29 files" by tiltaf, commit cf3d756, verified)
import torch
from PIL import Image
from torch.utils.data import Dataset
from data.processors import get_image_string
import logging
class BaseDataset(Dataset):
    """Shared base for chat-style vision-language datasets.

    Turns one raw item (texts plus optional images) into chat-templated
    token ids, a boolean loss mask covering only assistant tokens, and an
    attention mask. Subclasses implement item access and final packaging.
    """

    def __init__(self, dataset, tokenizer, image_processor, mp_image_token_length, relevance_min_rating=1, image_correspondence_min_rating=1, visual_dependency_min_rating=1, formatting_min_rating=1):
        # Raw samples; must support len() (see __len__) and item access/iteration.
        self.dataset = dataset
        # Tokenizer with a chat template and an `image_token` attribute.
        self.tokenizer = tokenizer
        # Callable: PIL RGB image -> (list of processed splits, split-count pair).
        self.image_processor = image_processor
        # Placeholder-token count per image split, forwarded to get_image_string.
        self.mp_image_token_length = mp_image_token_length
        # Per-turn quality thresholds; a turn rated below any of them is dropped.
        # Default of 1 keeps everything whose rating is >= 1.
        self.relevance_min_rating = relevance_min_rating
        self.image_correspondence_min_rating = image_correspondence_min_rating
        self.visual_dependency_min_rating = visual_dependency_min_rating
        self.formatting_min_rating = formatting_min_rating
        # Token count of chat-template boilerplate preceding assistant content;
        # used to exclude that boilerplate from the loss mask.
        self.prefix_len = self._get_prefix_len()

    def __len__(self):
        return len(self.dataset)

    def _get_prefix_len(self):
        """Measure how many tokens the chat template emits before an
        assistant message's content.

        Renders a single assistant turn containing a sentinel string,
        finds where the sentinel lands in the rendered text, and tokenizes
        everything before it.
        """
        random_string_5_letters = "xzyvd"
        random_string_chat_templated = self.tokenizer.apply_chat_template([{"role": "assistant", "content": random_string_5_letters}], tokenize=False, add_special_tokens=False)
        random_string_location = random_string_chat_templated.find(random_string_5_letters)
        # NOTE(review): assumes tokenizing the prefix substring in isolation
        # yields the same token count as within the full rendering — confirm
        # for tokenizers whose merges span the boundary.
        return len(self.tokenizer.encode(random_string_chat_templated[:random_string_location]))

    def _get_messages(self, item, splitted_image_counts):
        """Build the user/assistant message list for one item.

        Filters out turns whose quality ratings fall below the configured
        thresholds, strips stray image tokens, then prepends the image
        placeholder string to the first message when images are present.
        Returns an empty list when every turn was filtered out.
        """
        messages = []
        for index, text in enumerate(item['texts']):
            try:
                # Each rating list is optional; a None entry means "unrated"
                # and does not filter the turn.
                if item.get('relevance_ratings') is not None and item['relevance_ratings'][index] is not None and item['relevance_ratings'][index] < self.relevance_min_rating:
                    continue
                if item.get('image_correspondence_ratings') is not None and item['image_correspondence_ratings'][index] is not None and item['image_correspondence_ratings'][index] < self.image_correspondence_min_rating:
                    continue
                if item.get('visual_dependency_ratings') is not None and item['visual_dependency_ratings'][index] is not None and item['visual_dependency_ratings'][index] < self.visual_dependency_min_rating:
                    continue
                if item.get('formatting_ratings') is not None and item['formatting_ratings'][index] is not None and item['formatting_ratings'][index] < self.formatting_min_rating:
                    continue
            except Exception as e:
                # NOTE(review): on a rating-lookup error (e.g. index out of
                # range) the turn is KEPT — execution falls through to the
                # appends below. Confirm this lenient fallback is intended.
                logging.warning(f"Error processing item: {item}, index: {index}: {e}")
            messages.append({"role": "user", "content": text['user']})
            messages.append({"role": "assistant", "content": text['assistant']})
        if len(messages) == 0:
            return messages
        # Safety check to ensure no image tokens are present in the text before adding them.
        for msg in messages:
            if self.tokenizer.image_token in msg["content"]:
                logging.warning(f"Found and removed an image token in the {msg['role']} text before adding the image string.")
                msg["content"] = msg["content"].replace(self.tokenizer.image_token, "")
        if len(splitted_image_counts) > 0:
            # Prepend one placeholder string covering all images to the first
            # (user) message, so the model sees images before any text.
            image_string = get_image_string(self.tokenizer, splitted_image_counts, self.mp_image_token_length)
            messages[0]["content"] = image_string + messages[0]["content"]
        return messages

    def _process_images(self, images):
        """Run the image processor over each PIL image.

        Returns (processed_images, splitted_image_counts), one entry per
        input image. Raises ValueError for anything that is not a PIL image.
        """
        processed_images = []
        splitted_image_counts = []
        for image in images:
            if isinstance(image, Image.Image):
                # Processor expects 3-channel input.
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                processed_image, splitted_image_count = self.image_processor(image)
                if not hasattr(self.tokenizer, "global_image_token") and splitted_image_count[0]*splitted_image_count[1] == len(processed_image) - 1:
                    # If the tokenizer doesn't have a global image token, but the processor generated it, remove it
                    # (the extra leading entry is the global/overview image).
                    processed_image = processed_image[1:]
                processed_images.append(processed_image)
                splitted_image_counts.append(splitted_image_count)
            else:
                raise ValueError(f"Error processing image: {image}")
        return processed_images, splitted_image_counts

    def _prepare_inputs_and_loss_mask(self, messages):
        """Tokenize the full conversation and mark assistant tokens.

        Returns (input_ids, loss_mask, attention_mask) as tensors; the loss
        mask is boolean and is 1 only over assistant content (template
        boilerplate of length `prefix_len` excluded).
        """
        conv_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_special_tokens=False,
            return_dict=True,
        )
        mask = [0] * len(conv_ids["input_ids"])
        # Locate each assistant turn and flip its mask to 1
        cursor = 0
        for msg in messages:
            # NOTE(review): assumes templating each message separately and
            # concatenating reproduces the full conversation's tokenization
            # exactly (no extra BOS/system tokens per segment) — confirm for
            # this tokenizer's chat template, else the cursor drifts.
            segment_ids = self.tokenizer.apply_chat_template(
                [msg], tokenize=True, add_special_tokens=False
            )
            seg_len = len(segment_ids)
            if msg["role"] == "assistant":
                start = cursor + self.prefix_len
                end = cursor + seg_len
                mask[start:end] = [1] * (end - start)  # compute loss on these tokens
            cursor += seg_len
        return torch.tensor(conv_ids["input_ids"]), torch.tensor(mask).to(torch.bool), torch.tensor(conv_ids["attention_mask"])
class VQADataset(BaseDataset):
    """Visual Question Answering dataset.

    Packages each raw item into model-ready tensors (token ids, attention
    mask, shifted labels) together with its processed images. Items whose
    turns are all filtered out map to None.
    """

    def iter_for_worker(self):
        """Yield processed samples; with iterable datasets each worker
        receives its own shards, so iteration replaces indexing."""
        for sample in self.dataset:
            yield self._process_data(sample)

    def __getitem__(self, idx):
        return self._process_data(self.dataset[idx])

    def _process_data(self, item):
        """Convert one raw item into a training sample dict, or None if
        every conversation turn was filtered out."""
        # Normalize the image field to a (possibly empty) list.
        raw_images = item['images']
        if raw_images is None:
            raw_images = []
        elif not isinstance(raw_images, list):
            raw_images = [raw_images]

        processed_images = []
        splitted_image_counts = []
        if raw_images:  # skip the processor entirely for text-only items
            processed_images, splitted_image_counts = self._process_images(raw_images)

        messages = self._get_messages(item, splitted_image_counts)
        if not messages:
            return None

        input_ids, mask, attention_mask = self._prepare_inputs_and_loss_mask(messages)
        return {
            "images": processed_images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": self._get_labels(input_ids, mask),
        }

    def _get_labels(self, input_ids, mask):
        """Build causal-LM labels: copy ids where the loss mask is set,
        -100 elsewhere, shifted left by one so each position predicts the
        next token; the final position has no target."""
        labels = torch.where(mask, input_ids, torch.full_like(input_ids, -100))
        labels = torch.roll(labels, -1)  # shift targets left for next-token prediction
        labels[-1] = -100
        return labels