|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from typing import Any |
|
|
from ast import literal_eval |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
import paths |
|
|
from utils_ctc import sample_text_to_seq |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class myDatasetCTC(Dataset):
    """Dataset for CTC training: yields an image tensor plus the integer
    token sequence of its transcription.

    The ground-truth file lists one sample per line:
    ``<image_file_name> <transcription words...>``.
    """

    def __init__(self, partition="train"):
        """Load all (image path, transcription) pairs from the GT file.

        Args:
            partition: dataset split name; "train" enables random padding
                in the processor inside ``__getitem__``.
        """
        self.processor = None  # must be injected via set_processor() before __getitem__
        self.partition = partition
        # NOTE(review): text-to-index mapping consumed by sample_text_to_seq
        # in __getitem__; it was never initialized anywhere in this class —
        # TODO confirm where it is supposed to be assigned.
        self.text_to_seq = None

        # BUGFIX: the two assignments were swapped — path_labels pointed to
        # paths.IMAGE_PATH and path_images to paths.GT_PATH. path_labels is
        # opened and read line-by-line below, so it must be the GT file
        # (matches the sibling myDatasetTransformerDecoder class).
        self.path_labels = paths.GT_PATH
        self.path_images = paths.IMAGE_PATH
        self.image_name_list = []
        self.label_list = []

        # BUGFIX: the file handle was previously never closed; use a
        # context manager and iterate lazily instead of readlines().
        with open(self.path_labels, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if not parts:
                    # Skip blank lines (previously an IndexError on parts[0]).
                    continue
                self.image_name_list.append(self.path_images + parts[0])
                self.label_list.append(' '.join(parts[1:]))

        print("\tSamples Loaded: ", len(self.label_list), "\n-------------------------------------")

    def set_processor(self, processor):
        """Inject the image processor (set after construction)."""
        self.processor = processor

    def __len__(self):
        """Number of samples listed in the GT file."""
        return len(self.image_name_list)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: "idx" (int), "img" (processed pixel tensor), "label"
        (1-D tensor of token ids), "raw_label" (the transcription string).
        """
        with Image.open(self.image_name_list[idx]) as image:
            image = image.convert("RGB")
            image_tensor = np.array(image)
        label = self.label_list[idx]

        image_tensor = self.processor(
            image_tensor,
            # BUGFIX: was `self.partitions` (typo) — the attribute is
            # `self.partition`, so this line raised AttributeError.
            random_padding=self.partition == "train",
            return_tensors="pt"
        ).pixel_values
        image_tensor = image_tensor.squeeze()

        label_tensor = torch.tensor(sample_text_to_seq(label, self.text_to_seq))

        return {"idx": idx, "img": image_tensor, "label": label_tensor, "raw_label": label}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class myDatasetTransformerDecoder(Dataset):
    """Dataset for transformer-decoder training: yields processed pixel
    values plus padded token-id labels, with padding positions replaced by
    -100 so the loss ignores them.

    The ground-truth file lists one sample per line:
    ``<image_file_name> <transcription words...>``; each transcription is
    wrapped into a ``{"gt_parse": {"text_sequence": ...}}`` structure.
    """

    def __init__(self, partition="train"):
        """Read the GT file and build the (image path, wrapped label) lists.

        Args:
            partition: dataset split name; "train" enables random padding
                in the processor inside ``__getitem__``.
        """
        self.max_length = paths.MAX_LENGTH
        self.partition = partition
        self.processor = None  # injected later via set_processor()
        self.ignore_id = -100  # loss-ignore index written over padded label positions

        self.path_img = paths.IMAGE_PATH
        self.path_transcriptions = paths.GT_PATH
        self.image_name_list = []
        self.label_list = []

        with open(self.path_transcriptions, 'r') as gt_file:
            for raw in gt_file:
                tokens = raw.strip().split()
                img_name = tokens[0]
                transcription = ' '.join(tokens[1:])
                # Wrap the transcription in the pseudo-dict string that
                # __getitem__ parses back with literal_eval.
                wrapped = '{"gt_parse": {"text_sequence" : "' + transcription + '"}}'

                self.image_name_list.append(self.path_img + img_name)
                self.label_list.append(wrapped)

        print("\tSamples Loaded: ", len(self.label_list))

    def dict2token(self, obj: Any):
        """Flatten a gt_parse dict down to its raw text sequence."""
        return obj["text_sequence"]

    def set_processor(self, processor):
        """Inject the image/text processor used by ``__getitem__``."""
        self.processor = processor

    def __len__(self):
        """Number of samples listed in the GT file."""
        return len(self.image_name_list)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: "idx" (int), "img" (pixel tensor), "label" (token ids with
        pad positions set to -100), "raw_label" (list with the wrapped
        target string including the EOS token).
        """
        img = Image.open(self.image_name_list[idx]).convert("RGB")
        img_array = np.array(img)

        is_train = self.partition == "train"
        pixel_values: torch.Tensor = self.processor(
            img_array, random_padding=is_train, return_tensors="pt"
        ).pixel_values[0]

        parsed = literal_eval(self.label_list[idx])
        assert "gt_parse" in parsed and isinstance(parsed["gt_parse"], dict)
        eos = self.processor.tokenizer.eos_token
        target_sequence = [self.dict2token(d) + eos for d in [parsed["gt_parse"]]]

        tokenized = self.processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].squeeze(0)

        labels = input_ids.clone()
        pad_id = self.processor.tokenizer.pad_token_id
        labels[labels == pad_id] = self.ignore_id

        return {"idx": idx, "img": pixel_values, "label": labels, "raw_label": target_sequence}