|
|
import json
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
|
|
class trocrDataset(Dataset):
    """Dataset for TrOCR training.

    Expects image/label file pairs that share a stem:
        /tmp/0/0.jpg          # image
        /tmp/0/0.txt          # text label
        ....
        /tmp/100/10000.jpg    # image
        /tmp/100/10000.txt    # text label
    """

    def __init__(self, paths, processor, max_target_length=128, transformer=lambda x: x):
        """
        Args:
            paths: list of image file paths; each must have a sibling ``.txt`` label.
            processor: TrOCR processor exposing image preprocessing and a tokenizer.
            max_target_length: fixed length labels are truncated/padded to.
            transformer: optional image augmentation callable (default: identity).
        """
        self.paths = paths
        self.processor = processor
        self.transformer = transformer
        self.max_target_length = max_target_length
        self.nsamples = len(self.paths)
        self.vocab = processor.tokenizer.get_vocab()

    def __len__(self):
        return self.nsamples

    def __getitem__(self, idx):
        # Wrap the index so oversampling (idx >= nsamples) is safe.
        # BUG FIX: the original computed `inx` but then indexed with raw `idx`.
        inx = idx % self.nsamples
        image_file = self.paths[inx]
        txt_file = os.path.splitext(image_file)[0] + '.txt'

        with open(txt_file, encoding='utf-8') as f:
            # Strip non-breaking spaces. BUG FIX: the original used 'xa0'
            # (missing backslash), which deleted the literal substring "xa0"
            # from labels instead of U+00A0.
            text = f.read().strip().replace('\xa0', '')

        # A label written as a JSON list (e.g. ["<td>", "3", "</td>", ...])
        # is decoded back into a token list; plain text stays a string.
        if text.startswith('[') and text.endswith(']'):
            try:
                text = json.loads(text)
            except json.JSONDecodeError:
                pass  # not valid JSON after all -> keep the raw string

        image = Image.open(image_file).convert("RGB")
        image = self.transformer(image)  # optional augmentation
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        labels = encode_text(text, max_target_length=self.max_target_length, vocab=self.vocab)
        # Replace pad ids with -100 so the loss function ignores padding.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100
                  for label in labels]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
|
|
|
|
|
|
|
def encode_text(text, max_target_length=128, vocab=None):
    """Convert a label into a fixed-length list of token ids.

    Supports either a plain string (encoded character by character) or a
    custom token list, e.g. ``['<td>', '3', '3', '</td>', ...]``.

    The output is ``<s>`` + tokens + ``</s>``, truncated so the markers fit,
    then right-padded with ``<pad>`` up to ``max_target_length``. Tokens not
    present in the vocabulary map to ``<unk>``.

    Args:
        text: str or list of token strings.
        max_target_length: fixed output length.
        vocab: mapping from token string to id; must contain the special
            tokens '<s>', '</s>', '<unk>' and '<pad>'.

    Returns:
        list[int] of exactly ``max_target_length`` ids.
    """
    if not isinstance(text, list):
        text = list(text)

    # Reserve two slots for the <s> / </s> markers.
    text = text[:max_target_length - 2]

    unk = vocab.get('<unk>')
    pad = vocab.get('<pad>')

    # NOTE: the original also accumulated an attention mask here, but it was
    # never returned; that dead code has been removed.
    tokens = [vocab.get('<s>')]
    tokens.extend(vocab.get(tk, unk) for tk in text)
    tokens.append(vocab.get('</s>'))

    # Right-pad to the fixed target length.
    if len(tokens) < max_target_length:
        tokens.extend([pad] * (max_target_length - len(tokens)))

    return tokens
|
|
|
|
|
|
|
|
|
|
|
def decode_text(tokens, vocab, vocab_inp):
    """Map token ids back to a text string, dropping special tokens.

    Args:
        tokens: iterable of token ids.
        vocab: token -> id mapping, used only to look up the ids of the
            special tokens '<s>', '</s>', '<unk>' and '<pad>'.
        vocab_inp: id -> token inverse mapping.

    Returns:
        The concatenation of all non-special tokens, in order.
    """
    # Set membership instead of a list scan per token, and ''.join instead
    # of quadratic string concatenation.
    specials = {vocab.get('<s>'), vocab.get('</s>'),
                vocab.get('<unk>'), vocab.get('<pad>')}
    return ''.join(vocab_inp[tk] for tk in tokens if tk not in specials)
|
|
|
|
|
|
|
|
|
|
|
|