import json
from ast import literal_eval
from typing import Any

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

import paths
from utils_ctc import sample_text_to_seq
|
|
| |
| |
| |
|
|
class myDatasetCTC(Dataset):
    """Image/transcription dataset that yields samples for CTC training.

    Each non-blank line of the ground-truth file at ``paths.GT_PATH`` is split
    on whitespace: the first token is an image file name, which is appended to
    ``paths.IMAGE_PATH``, and the remaining tokens, joined with single spaces,
    form the transcription.

    Two things must be supplied by the caller before the dataset is indexed:

    * a processor, through :meth:`set_processor`;
    * a ``text_to_seq`` attribute, which is passed unchanged as the second
      argument of ``sample_text_to_seq`` (presumably a character-to-index
      vocabulary -- TODO confirm against ``utils_ctc``).
    """

    def __init__(self, partition = "train"):
        """Read the list of samples from the ground-truth file.

        Args:
            partition: Name of the split. The processor is asked for random
                padding only when this equals ``"train"``.
        """
        self.processor = None
        self.partition = partition

        # The transcriptions live in GT_PATH and the images under IMAGE_PATH,
        # the same mapping myDatasetTransformerDecoder uses.
        self.path_labels = paths.GT_PATH
        self.path_images = paths.IMAGE_PATH
        self.image_name_list, self.label_list = self._read_samples(
            self.path_labels, self.path_images
        )

        print("\tSamples Loaded: ", len(self.label_list), "\n-------------------------------------")

    @staticmethod
    def _read_samples(label_file, image_root):
        """Parse a ground-truth file into parallel image-path and label lists.

        Args:
            label_file: Path of the text file holding one sample per line.
            image_root: String prepended to every image file name.

        Returns:
            A ``(image_paths, labels)`` tuple of two equally long lists.
        """
        image_paths = []
        labels = []
        with open(label_file, "r") as handle:
            for raw_line in handle:
                fields = raw_line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                image_paths.append(image_root + fields[0])
                labels.append(" ".join(fields[1:]))
        return image_paths, labels

    def set_processor(self, processor):
        """Attach the processor used to turn images into pixel tensors."""
        self.processor = processor

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_name_list)

    def _preprocess_image(self, image_array):
        """Run the processor on one RGB array and drop the batch axis.

        Args:
            image_array: The image as a NumPy array.

        Returns:
            The processor's ``pixel_values`` with the leading axis squeezed.
        """
        pixel_values = self.processor(
            image_array,
            random_padding=self.partition == "train",
            return_tensors="pt",
        ).pixel_values
        # Squeeze only the leading axis: a bare squeeze() would also remove a
        # height or width axis that happens to have size 1.
        return pixel_values.squeeze(0)

    def __getitem__(self, idx):
        """Load, preprocess and encode the sample at position ``idx``.

        Returns:
            A dict with keys ``"idx"``, ``"img"`` (processed image tensor),
            ``"label"`` (tensor built from ``sample_text_to_seq``) and
            ``"raw_label"`` (the transcription string).

        Raises:
            RuntimeError: If no processor has been set.
            AttributeError: If ``text_to_seq`` has not been assigned.
        """
        # Fail fast, before any image I/O, when the dataset is not fully set up.
        if self.processor is None:
            raise RuntimeError(
                "myDatasetCTC has no processor; call set_processor() before indexing"
            )
        try:
            text_to_seq = self.text_to_seq
        except AttributeError:
            raise AttributeError(
                "myDatasetCTC.text_to_seq must be assigned before indexing the dataset"
            ) from None

        with Image.open(self.image_name_list[idx]) as image:
            image_array = np.array(image.convert("RGB"))

        label = self.label_list[idx]
        image_tensor = self._preprocess_image(image_array)
        label_tensor = torch.tensor(sample_text_to_seq(label, text_to_seq))

        return {"idx": idx, "img": image_tensor, "label": label_tensor, "raw_label": label}
|
|
|
|
| |
| |
| |
|
|
class myDatasetTransformerDecoder(Dataset):
    """Image/transcription dataset that yields tokenized decoder targets.

    Each non-blank line of the ground-truth file at ``paths.GT_PATH`` is split
    on whitespace: the first token is an image file name, which is appended to
    ``paths.IMAGE_PATH``, and the remaining tokens, joined with single spaces,
    form the transcription. Every transcription is stored in ``label_list`` as
    the string ``{"gt_parse": {"text_sequence" : "<text>"}}``.

    A processor exposing a ``tokenizer`` attribute must be attached with
    :meth:`set_processor` before the dataset is indexed.
    """

    def __init__(self, partition="train"):
        """Read the list of samples from the ground-truth file.

        Args:
            partition: Name of the split. The processor is asked for random
                padding only when this equals ``"train"``.
        """
        self.max_length = paths.MAX_LENGTH
        self.partition = partition
        self.processor = None
        # Value written over padding positions of the returned label tensor
        # (presumably the loss function's ignore index -- TODO confirm).
        self.ignore_id = -100

        self.path_img = paths.IMAGE_PATH
        self.path_transcriptions = paths.GT_PATH
        self.image_name_list, self.label_list = self._read_samples(
            self.path_transcriptions, self.path_img
        )

        print("\tSamples Loaded: ", len(self.label_list))

    @staticmethod
    def _read_samples(label_file, image_root):
        """Parse a ground-truth file into image paths and serialized labels.

        Args:
            label_file: Path of the text file holding one sample per line.
            image_root: String prepended to every image file name.

        Returns:
            A ``(image_paths, labels)`` tuple of two equally long lists, where
            each label is the serialized ``gt_parse`` string.
        """
        template = '{"gt_parse": {"text_sequence" : '
        image_paths = []
        labels = []
        with open(label_file, "r") as handle:
            for raw_line in handle:
                fields = raw_line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                text = " ".join(fields[1:])
                # json.dumps escapes quotes and backslashes so the stored string
                # always parses back to the exact transcription; for ordinary
                # text the result is just the text wrapped in double quotes.
                labels.append(template + json.dumps(text, ensure_ascii=False) + "}}")
                image_paths.append(image_root + fields[0])
        return image_paths, labels

    @staticmethod
    def _parse_label(serialized):
        """Decode one ``label_list`` entry back into its dictionary form."""
        return json.loads(serialized)

    def dict2token(self, obj: Any):
        """Return the ``"text_sequence"`` entry of a ``gt_parse`` mapping."""
        return obj["text_sequence"]

    def set_processor(self, processor):
        """Attach the processor used for images and tokenization."""
        self.processor = processor

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_name_list)

    def __getitem__(self, idx):
        """Load, preprocess and tokenize the sample at position ``idx``.

        Returns:
            A dict with keys ``"idx"``, ``"img"`` (first element of the
            processor's ``pixel_values``), ``"label"`` (token ids with padding
            positions replaced by ``ignore_id``) and ``"raw_label"`` (a
            one-element list holding the transcription plus the EOS token).

        Raises:
            RuntimeError: If no processor has been set.
        """
        if self.processor is None:
            raise RuntimeError(
                "myDatasetTransformerDecoder has no processor; "
                "call set_processor() before indexing"
            )

        # The context manager releases the file handle as soon as the pixels
        # have been copied into the array.
        with Image.open(self.image_name_list[idx]) as image:
            image_array = np.array(image.convert("RGB"))

        pixel_values: torch.Tensor = self.processor(
            image_array,
            random_padding=self.partition == "train",
            return_tensors="pt",
        ).pixel_values[0]

        label = self._parse_label(self.label_list[idx])
        assert "gt_parse" in label and isinstance(label["gt_parse"], dict)

        tokenizer = self.processor.tokenizer
        gt_dicts = [label["gt_parse"]]
        target_sequence = [
            self.dict2token(gt_dict) + tokenizer.eos_token for gt_dict in gt_dicts
        ]

        input_ids = tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        # Work on a copy and overwrite padding positions with ignore_id.
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = self.ignore_id

        return {"idx": idx, "img": pixel_values, "label": labels, "raw_label": target_sequence}