import os
import json

from PIL import Image
from torch.utils.data import Dataset
import torch


class trocrDataset(Dataset):
    """TrOCR training dataset.

    Each sample is an image file paired with a text label stored next to it
    under the same stem with a ``.txt`` extension::

        /tmp/0/0.jpg            # image
        /tmp/0/0.txt            # text label
        ....
        /tmp/100/10000.jpg      # image
        /tmp/100/10000.txt      # text label

    A label whose stripped content starts with ``[`` and ends with ``]`` is
    parsed as JSON (e.g. a list of tokens); if it is not valid JSON the raw
    string is used instead.

    Args:
        paths: Sequence of image file paths.
        processor: Object that is callable as
            ``processor(image, return_tensors="pt")`` returning something
            with ``.pixel_values``, and that has a ``tokenizer`` providing
            ``get_vocab()`` and ``pad_token_id`` (presumably a Hugging Face
            TrOCR processor -- confirm against callers).
        max_target_length: Forwarded to ``encode_text`` as the label length.
        transformer: Image augmentation callable applied to the RGB PIL
            image before the processor; identity by default.
    """

    def __init__(self, paths, processor, max_target_length=128,
                 transformer=lambda x: x):
        self.paths = paths
        self.processor = processor
        self.transformer = transformer
        self.max_target_length = max_target_length
        self.nsamples = len(self.paths)
        # Token -> id mapping, passed to encode_text for every sample.
        self.vocab = processor.tokenizer.get_vocab()

    def __len__(self):
        return self.nsamples

    @staticmethod
    def _load_label(txt_file):
        """Read the label in *txt_file*; parse it as JSON if it looks like a list."""
        with open(txt_file, encoding="utf-8") as f:
            # Remove non-breaking spaces (U+00A0); the original code removed
            # the literal substring "xa0" by mistake.
            text = f.read().strip().replace("\xa0", "")
        if text.startswith("[") and text.endswith("]"):
            # Bracketed label: treat as a JSON list, but keep the raw string
            # when it is not valid JSON (deliberate best-effort fallback).
            try:
                text = json.loads(text)
            except json.JSONDecodeError:
                pass
        return text

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": tensor}`` for sample *idx*.

        *idx* wraps modulo the dataset length.
        """
        image_file = self.paths[idx % self.nsamples]
        txt_file = os.path.splitext(image_file)[0] + ".txt"
        text = self._load_label(txt_file)

        # Context manager closes the file handle; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(image_file) as img:
            image = img.convert("RGB")
        image = self.transformer(image)  # image augmentation
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        labels = encode_text(text, max_target_length=self.max_target_length,
                             vocab=self.vocab)
        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss;
        # presumably the training loss skips these padded positions.
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [-100 if label == pad_id else label for label in labels]

        return {
            # Presumably the processor returns a (1, C, H, W) batch for one
            # image; drop only that leading batch dimension.
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels),
        }