import json
import logging
import os
import random
import re
import time
import traceback
import warnings
from io import BytesIO

import pandas as pd
import h5py
import numpy as np
import torch
from icecream import ic
from PIL import Image, ImageFile
from torch.utils.data import Dataset, Subset

from utils import get_args

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)
def load_jsonl(filename):
    with open(filename, "r", encoding="utf-8") as f:
        return [json.loads(line.strip("\n")) for line in f]
class MultiModalDataset(Dataset):
    """MultiModal dataset wrapping a single (video, text) pair."""

    def __init__(self, videopath, text, tokenizer, processor,
                 max_length=2048,
                 media_tokens=['<image>', '<|video|>'], loss_objective='sequential'):
        args = get_args()
        self.loss_objective = loss_objective
        if 'sequential' not in self.loss_objective:
            raise NotImplementedError('dataset loader not implemented for other loss objectives')

        self.videopath = videopath
        self.text = text
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processor = processor
        # Map media tokens to negative ids so they can never collide with real
        # vocabulary ids (the loss mask below relies on id < 0 meaning "media").
        self.media_tokens = {k: -int(i + 1) for i, k in enumerate(media_tokens)}
        # Each media token expands to 1 special token + 64 visual tokens.
        self.media_lengths = {'<image>': 1 + 64, '<|video|>': 1 + 64}
        print("num_media_token: ", self.media_lengths)
        self.bucket = {}

    def __len__(self):
        # A single (video, text) pair is wrapped, so the dataset has length 1.
        return 1

    def __getitem__(self, index):
        videopath = self.videopath
        caption = self.text
        video_input = self.processor(videos=[videopath], num_frames=32, return_tensors='pt')  # video_pixel_values
        text_input = self._extract_text_token_from_conversation(caption, self.max_length, index)
        item = {'video': video_input, 'text': text_input, 'videopath': videopath, 'caption': caption}
        return item
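
    # The method below tokenizes a conversation string into fixed-length
    # training tensors. The string is split on 'AI: ' / '\nHuman: ' markers
    # (plus media tokens); spans that follow 'AI: ' are labeled 1 in
    # label_chunk (they receive loss), everything else 0. Media tokens expand
    # into runs of negative placeholder ids, and the result is padded or
    # truncated to exactly max_length + 1 ids.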
    def _extract_text_token_from_conversation(self, data, max_length, index):
        # output enc_chunk
        enc_chunk = []
        if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id > 0:
            prompt_chunk = [self.tokenizer.bos_token_id]
        else:
            prompt_chunk = []
        # conversation = data["completion"]
        conversation = data

        # For text-only data
        if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
            pattern = '|'.join(map(re.escape, ['AI: ', '\nHuman: ']))
            chunk_strs = re.split(f'({pattern})', conversation)
            prompt_length = -1  # sentinel for the text-only branch
            stop_flag = False
            for idx, chunk_str in enumerate(chunk_strs):
                if idx == 0:
                    enc_chunk = prompt_chunk + \
                        self.tokenizer(chunk_str, add_special_tokens=False)['input_ids']
                    enc_length = len(enc_chunk)
                    label_chunk = [0] * enc_length
                else:
                    if chunk_strs[idx - 1] == 'AI: ':
                        # Assistant span: contributes to the loss (label 1).
                        curr_chunk = self.tokenizer(chunk_str, add_special_tokens=False)['input_ids']
                        if enc_length + len(curr_chunk) >= max_length:
                            curr_chunk = curr_chunk[:max_length - enc_length]
                            stop_flag = True
                        curr_chunk += [self.tokenizer.eos_token_id]
                        enc_length += len(curr_chunk)
                        enc_chunk += curr_chunk
                        label_chunk += [1] * len(curr_chunk)
                    else:
                        # Prompt span: masked out of the loss (label 0).
                        curr_chunk = self.tokenizer(chunk_str, add_special_tokens=False)['input_ids']
                        if enc_length + len(curr_chunk) >= max_length + 1:
                            curr_chunk = curr_chunk[:max_length + 1 - enc_length]
                            stop_flag = True
                        enc_length += len(curr_chunk)
                        enc_chunk += curr_chunk
                        label_chunk += [0] * len(curr_chunk)
                if stop_flag:
                    break
        # For image/video-text data
        else:
            enc_length = 0
            prompt_length = -2  # sentinel for the media branch
            pattern = '|'.join(
                map(re.escape, list(self.media_tokens.keys()) + ['AI: ', '\nHuman: ']))
            chunk_strs = re.split(f'({pattern})', conversation)
            chunk_strs = [x for x in chunk_strs if len(x) > 0]
            for idx, chunk_str in enumerate(chunk_strs):
                if enc_length >= max_length + 1:
                    break
                if idx == 0:
                    enc_chunk = prompt_chunk + \
                        self.tokenizer(chunk_str, add_special_tokens=False)['input_ids']
                    enc_length = len(enc_chunk)
                    label_chunk = [0] * enc_length
                else:
                    if chunk_str in self.media_tokens:
                        # Expand the media token to its placeholder run
                        # (1 + 64 ids, see self.media_lengths).
                        if enc_length + self.media_lengths[chunk_str] > max_length + 1:
                            break
                        else:
                            enc_chunk += [self.media_tokens[chunk_str]] * self.media_lengths[chunk_str]
                            enc_length += self.media_lengths[chunk_str]
                            label_chunk += [0] * self.media_lengths[chunk_str]
                    else:
                        if chunk_strs[idx - 1] == 'AI: ':
                            # Assistant span: contributes to the loss (label 1).
                            curr_chunk = self.tokenizer(chunk_str, add_special_tokens=False)['input_ids']
                            if enc_length + len(curr_chunk) >= max_length:
                                curr_chunk = curr_chunk[:max_length - enc_length]
                            curr_chunk += [self.tokenizer.eos_token_id]
                            enc_length += len(curr_chunk)
                            enc_chunk += curr_chunk
                            label_chunk += [1] * len(curr_chunk)
                        else:
                            # Prompt span: masked out of the loss (label 0).
                            curr_chunk = self.tokenizer(chunk_str, add_special_tokens=False)['input_ids']
                            if enc_length + len(curr_chunk) >= max_length + 1:
                                curr_chunk = curr_chunk[:max_length + 1 - enc_length]
                            enc_length += len(curr_chunk)
                            enc_chunk += curr_chunk
                            label_chunk += [0] * len(curr_chunk)
        # Pad to exactly max_length + 1 ids.
        if enc_length < max_length + 1:
            padding_chunk = [self.tokenizer.pad_token_id] * (max_length + 1 - enc_length)
            padding_length = len(padding_chunk)
            label_chunk += [0] * (max_length + 1 - enc_length)
            enc_chunk = enc_chunk + padding_chunk
        else:
            padding_length = 0

        assert enc_length + padding_length == max_length + 1, \
            (index, prompt_length, enc_length, padding_length, max_length + 1)
        assert len(label_chunk) == max_length + 1, (len(label_chunk), max_length + 1)

        # 1 for every real (non-padding) position of the shifted sequence.
        non_padding_mask = [1 if i < enc_length - 1 else 0 for i in range(max_length)]
        enc_chunk = torch.tensor(enc_chunk).long()
        non_padding_mask = torch.tensor(non_padding_mask).long()
        prompt_mask = torch.tensor(label_chunk)[1:].long()
        prompt_length = torch.tensor([prompt_length]).long()

        # Create loss mask: zero out media-placeholder positions (negative ids).
        if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
            non_media_mask = torch.ones_like(non_padding_mask).long()
        else:
            tmp_enc_chunk = enc_chunk.clone()
            tmp_enc_chunk[tmp_enc_chunk >= 0] = 1
            tmp_enc_chunk[tmp_enc_chunk < 0] = 0
            non_media_mask = tmp_enc_chunk.long()  # already a tensor; avoid torch.tensor(tensor)
            non_media_mask = non_media_mask[1:].long()

        return {'input_ids': enc_chunk, "prompt_length": prompt_length, 'seq_length': enc_length,
                "non_padding_mask": non_padding_mask, 'non_media_mask': non_media_mask, 'prompt_mask': prompt_mask}