# Dataset wrappers for TextVQA and DocVQA evaluation.
| import json | |
| import os | |
| import re | |
| from torch.utils.data import Dataset | |
def prompt_processor(prompt):
    """Extract the bare question text from a formatted VQA prompt.

    Supported prompt layouts:
      1. Starts with ``"OCR tokens: "`` — the question sits between
         ``"Question: "`` and ``" Short answer:"`` on the same (possibly
         multi-line, hence DOTALL) prompt.
      2. A three-line prompt containing ``"Reference OCR token: "`` — the
         question is on the second line when the prompt *starts* with the
         OCR-token line, otherwise on the first line.
      3. A two-line prompt whose first line is the question.

    Returns:
        str: the extracted question, lower-cased.

    Raises:
        ValueError: if the prompt matches none of the known layouts, or the
            ``Question: ... Short answer:`` pattern cannot be found.
    """
    if prompt.startswith('OCR tokens: '):
        match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
        if match is None:
            # The original called match.group(1) unconditionally, raising an
            # opaque AttributeError on a non-matching prompt; fail loudly.
            raise ValueError(f"Could not locate question in prompt: {prompt!r}")
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = prompt.split('\n')[1]
        else:
            question = prompt.split('\n')[0]
    elif len(prompt.split('\n')) == 2:
        question = prompt.split('\n')[0]
    else:
        # `assert False` is stripped under `python -O`, which would then hit
        # an UnboundLocalError on `question`; raise an explicit error instead.
        raise ValueError(f"Unrecognized prompt format: {prompt!r}")
    return question.lower()
class textVQADataset(Dataset):
    """TextVQA evaluation dataset.

    Loads the annotation JSON once at construction and yields one dict per
    question with the resolved image path and ground-truth answers.
    """

    def __init__(
        self,
        image_dir="./downloads/TextVQA/train_images",
        ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
    ):
        """
        Args:
            image_dir: directory containing ``<image_id>.jpg`` files.
            ann_path: path to the TextVQA annotation JSON (must contain a
                top-level ``"data"`` list).
        """
        # Context manager closes the annotation file promptly; the original
        # `json.load(open(...))` left the handle open until GC.
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample dict: question_id, image_path, question, gt_answers."""
        record = self.data[idx]
        return {
            "question_id": record['question_id'],
            # Images are stored flat as "<image_id>.jpg" under image_dir.
            "image_path": os.path.join(self.image_dir, f"{record['image_id']}.jpg"),
            "question": record['question'],
            "gt_answers": record['answers'],
        }
class docVQADataset(Dataset):
    """DocVQA validation dataset (answers and question types included).

    Loads the annotation JSON once at construction; optionally loads a
    per-image OCR-token file keyed by ``image_id``.
    """

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/val_v1.0_withQT.json",
        ocr_token_path=None,
    ):
        """
        Args:
            image_dir: root directory of the DocVQA images.
            ann_path: path to the validation annotation JSON (top-level
                ``"data"`` list).
            ocr_token_path: optional JSON file whose ``"data"`` entries carry
                an ``image_id`` key; loaded into ``self.ocr_token_data``.
        """
        # Context managers close the files promptly; the original
        # `json.load(open(...))` left the handles open until GC.
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path
        # Always define the attribute so lookups never hit AttributeError;
        # the original only set it when ocr_token_path was provided.
        self.ocr_token_data = {}
        if ocr_token_path:
            with open(ocr_token_path, "r") as f:
                self.ocr_token_data = {
                    item['image_id']: item for item in json.load(f)["data"]
                }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample dict: question_id, image_path, question, gt_answers, question_type."""
        record = self.data[idx]
        # Annotation paths say "documents/...", but the images on disk live
        # under "images/..." — rewrite the prefix before joining.
        img_path = os.path.join(
            self.image_dir, record['image'].replace("documents", "images")
        )
        return {
            "question_id": record['questionId'],
            "image_path": img_path,
            "question": record['question'],
            "gt_answers": record['answers'],
            'question_type': record['question_types'],
        }
class docVQATESTDataset(Dataset):
    """DocVQA *test* split dataset — no ground-truth answers available.

    Mirrors :class:`docVQADataset` but the test annotations carry no answers
    or question types, so those fields are returned as empty strings.
    """

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/test_v1.0.json",
        ocr_token_path=None,
    ):
        """
        Args:
            image_dir: root directory of the DocVQA images.
            ann_path: path to the test annotation JSON (top-level ``"data"``).
            ocr_token_path: accepted for signature parity with docVQADataset
                but unused by this split.
        """
        # Context manager closes the annotation file promptly; the original
        # `json.load(open(...))` left the handle open until GC.
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample dict; gt_answers/question_type are empty on the test split."""
        record = self.data[idx]
        # Annotation paths say "documents/...", but the images on disk live
        # under "images/..." — rewrite the prefix before joining.
        img_path = os.path.join(
            self.image_dir, record['image'].replace("documents", "images")
        )
        return {
            "question_id": record['questionId'],
            "image_path": img_path,
            "question": record['question'],
            "gt_answers": "",
            'question_type': "",
        }