Spaces:
Runtime error
Runtime error
| import json | |
| from dataset.base_dataset import BaseDataset | |
| from dataset.utils import pre_text, load_anno | |
| from dataset.video_utils import VIDEO_READER_FUNCS | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| class ImageQADataset(BaseDataset): | |
| media_type = "image" | |
| def __init__(self, ann_file, transform, eos="[SEP]", mode="train", answer_list=None): | |
| super(ImageQADataset, self).__init__() | |
| assert mode in ["train", "eval"] | |
| self.mode = mode | |
| self.transform = transform | |
| self.eos = eos | |
| self.anno_list = load_anno(ann_file) | |
| if mode == "eval": | |
| self.answer_list = json.load(open(answer_list, "r")) | |
| def __len__(self): | |
| return len(self.anno_list) | |
| def get_answers_with_weights(self, raw_answers): | |
| if isinstance(raw_answers, str): | |
| raw_answers = [raw_answers] | |
| answer_weight = {} | |
| for answer in raw_answers: | |
| if answer in answer_weight.keys(): | |
| answer_weight[answer] += 1/len(raw_answers) | |
| else: | |
| answer_weight[answer] = 1/len(raw_answers) | |
| answers = list(answer_weight.keys()) | |
| weights = [answer_weight[a] for a in answers] | |
| answers = [answer + " " + self.eos for answer in answers] | |
| return answers, weights | |
| def __getitem__(self, index): | |
| ann = self.anno_list[index] | |
| image, index = self.load_and_transform_media_data(index, ann["media"]) | |
| question = pre_text(ann["question"]) | |
| if self.mode == "train": | |
| answers, weights = self.get_answers_with_weights(ann["answer"]) | |
| return image, question, answers, weights | |
| else: # self.mode == "eval": | |
| question_id = ann["question_id"] | |
| return image, question, question_id | |
| class VideoQADataset(ImageQADataset): | |
| media_type = "video" | |
| def __init__( | |
| self, ann_file, transform, eos="[SEP]", mode="train", answer_list=None, | |
| num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=1 | |
| ): | |
| super(VideoQADataset, self).__init__( | |
| ann_file, transform, eos, mode, answer_list) | |
| self.num_frames = num_frames | |
| self.video_reader_type = video_reader_type | |
| self.video_reader = VIDEO_READER_FUNCS[video_reader_type] | |
| self.sample_type = sample_type | |
| self.num_tries = num_tries | |