Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import json | |
| import sqlite3 | |
| import random | |
| from os.path import basename | |
| import numpy as np | |
| import datetime | |
| from dataset.base_dataset import ImageVideoBaseDataset | |
| from dataset.video_utils import VIDEO_READER_FUNCS | |
| logger = logging.getLogger(__name__) | |
| IMAGE_TOKEN="<image>" | |
| class ITImgTrainDataset(ImageVideoBaseDataset): | |
| media_type = "image" | |
| def __init__( | |
| self, ann_file, transform, | |
| system="", role=("Human", "Assistant"), | |
| mm_alone=True, | |
| add_second_msg=True, | |
| start_token="<Image>", end_token="</Image>", | |
| random_shuffle=True, # if True, shuffle the QA list ##xl:????? why need random shuffle | |
| begin_signal=None, | |
| end_signal=None, | |
| clip_transform=False, | |
| skip_short_sample=False, | |
| ): | |
| super().__init__() | |
| self.mm_alone = mm_alone | |
| self.clip_transform = clip_transform | |
| if len(ann_file) == 3 and ann_file[2] == "video": | |
| self.media_type = "video" | |
| else: | |
| self.media_type = "image" | |
| self.label_file, self.data_root = ann_file[:2] | |
| logger.info('Load json file') | |
| with open(self.label_file, 'r') as f: | |
| self.anno = json.load(f) | |
| self.num_examples = len(self.anno) | |
| self.transform = transform | |
| annos = [] | |
| for ann in self.anno: | |
| filename = ann['video'] if 'video' in ann else ann['image'] | |
| if self.media_type =='video' and "webvid" in self.data_root: | |
| video_id, extension = os.path.splitext(os.path.basename(filename)) | |
| if video_id not in self.keys_indexfile: | |
| pass | |
| else: | |
| annos.append(ann) | |
| else: | |
| if filename is None or filename=="None": | |
| pass | |
| else: | |
| if os.path.exists(os.path.join(self.data_root, filename)): | |
| annos.append(ann) | |
| else: | |
| ... | |
| self.anno = annos | |
| self.num_examples = len(self.anno) | |
| # prompt parameters | |
| if system: | |
| assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token." | |
| # currently not support add start_token and end_token in the system, since the msg should be added properly | |
| self.begin_signal = [begin_signal for _ in role] if isinstance(begin_signal, str) else begin_signal | |
| self.end_signal = [end_signal for _ in role] if isinstance(end_signal, str) else end_signal | |
| self.start_token = start_token | |
| self.end_token = end_token | |
| self.system = system | |
| self.role = role | |
| self.random_shuffle = random_shuffle | |
| # instruction location and number | |
| logger.info(f"Random shuffle: {self.random_shuffle}") | |
| def get_anno(self, index): | |
| filename = self.anno[index][self.media_type] | |
| qa = self.anno[index]["QA"] | |
| if "start" in self.anno[index] and "end" in self.anno[index]: | |
| anno = { | |
| "image": os.path.join(self.data_root, filename), "qa": qa, | |
| "start": self.anno[index]["start"], "end": self.anno[index]["end"], | |
| } | |
| else: | |
| anno = {"image": os.path.join(self.data_root, filename), "qa": qa} | |
| return anno | |
| def __len__(self): | |
| return self.num_examples | |
| def process_qa(self, qa, msg=""): | |
| cur_instruction = "" | |
| # randomly shuffle qa for conversation | |
| if self.random_shuffle and len(qa) > 1: | |
| random.shuffle(qa) | |
| if "i" in qa[0].keys() and qa[0]["i"] != "": | |
| cur_instruction = qa[0]["i"] + self.end_signal[0] | |
| conversation = self.system | |
| # add instruction as system message | |
| if cur_instruction: | |
| conversation += cur_instruction | |
| # rstrip() for the extra " " in msg | |
| if self.mm_alone: | |
| conversation += ( | |
| self.begin_signal[0] + self.role[0] + | |
| self.start_token + self.end_token + msg.rstrip() + self.end_signal[0] | |
| ) | |
| for i, sentence in enumerate(qa): | |
| q = self.start_token + self.end_token+"\n"+ qa[0]["q"] if (not self.mm_alone) and (i == 0) else sentence["q"] | |
| a = sentence["a"] | |
| if q != "": | |
| conversation += (self.begin_signal[0] + self.role[0] + q + self.end_signal[1]) | |
| else: | |
| # no question, often in caption dataset | |
| pass | |
| conversation += (self.begin_signal[0] + self.role[1] + a + self.end_signal[1]) | |
| if cur_instruction: | |
| cur_instruction += qa[0]["q"] | |
| return conversation, cur_instruction.strip() | |
| def __getitem__(self, index): | |
| try: | |
| ann = self.get_anno(index) | |
| image, index = self.load_and_transform_media_data_image(index, ann["image"], clip_transform=self.clip_transform) | |
| conversation, instruction = self.process_qa(ann["qa"]) | |
| return image, conversation, instruction, index | |
| except Exception as e: | |
| logger.warning(f"Caught exception {e} when loading image {ann['image']}") | |
| index = np.random.randint(0, len(self)) | |
| return self.__getitem__(index) | |
| class ITVidTrainDataset(ITImgTrainDataset): | |
| media_type = "video" | |
| def __init__( | |
| self, ann_file, transform, | |
| num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3, | |
| mm_alone=True, | |
| system="", role=("Human", "Assistant"), | |
| start_token="<Video>", end_token="</Video>", | |
| add_second_msg=True, | |
| random_shuffle=True, | |
| begin_signal=None, | |
| end_signal=None, | |
| clip_transform=False, | |
| skip_short_sample=False, | |
| ): | |
| # "id index file for webvid" | |
| if "webvid" in ann_file[1]: | |
| with open("/mnt/bn/dq-storage-ckpt/xulin/datasets/videos/webvid_10m/keys_indexfile.json") as f: | |
| self.keys_indexfile = json.load(f) # the correponding index file for each webvid id | |
| super().__init__( | |
| ann_file, transform, | |
| system=system, role=role, | |
| mm_alone=mm_alone, | |
| start_token=start_token, end_token=end_token, | |
| random_shuffle=random_shuffle, | |
| begin_signal=begin_signal, | |
| end_signal=end_signal, | |
| clip_transform=clip_transform, | |
| skip_short_sample=skip_short_sample, | |
| ) | |
| self.num_frames = num_frames | |
| self.video_reader_type = video_reader_type | |
| self.video_reader = VIDEO_READER_FUNCS[video_reader_type] | |
| self.sample_type = sample_type | |
| self.num_tries = num_tries | |
| self.add_second_msg = add_second_msg | |
| logger.info(f"Use {video_reader_type} for data in {ann_file}") | |
| if add_second_msg: | |
| logger.info(f"Add second message: The video contains X frames sampled at T seconds.") | |
| def __getitem__(self, index): | |
| try: | |
| ann = self.get_anno(index) | |
| msg = "" | |
| clip = None | |
| if "start" in ann and "end" in ann: | |
| clip = [ann["start"], ann["end"]] | |
| video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip, clip_transform=self.clip_transform) | |
| if self.add_second_msg: | |
| # " " should be added in the start and end | |
| msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. " | |
| conversation, instruction = self.process_qa(ann["qa"], msg) | |
| return video, conversation, instruction, index | |
| except Exception as e: | |
| logger.warning(f"Caught exception {e} when loading video {ann['image']}") | |
| index = np.random.randint(0, len(self)) | |
| return self.__getitem__(index) |