"""Single-image OCR inference with a fine-tuned TrOCR model.

Loads a processor + VisionEncoderDecoderModel from a local weights
directory, runs greedy generation on one test image, and prints the
decoded text together with the elapsed wall-clock time.
"""
import argparse
import os
import time

import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from dataset import decode_text

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='trocr fine-tune训练')
    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="初始化训练权重,用于自己数据集上fine-tune权重")
    parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
    parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")
    args = parser.parse_args()

    # Apply the GPU-selection flag. It was previously parsed but never used
    # (which also left the `os` import dead); it must be set before any CUDA
    # context is created for it to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.CUDA_VISIBLE_DEVICES

    processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
    # token -> id mapping from the tokenizer, and its inverse for decoding.
    # (Previously computed twice verbatim; once is enough.)
    vocab = processor.tokenizer.get_vocab()
    vocab_inp = {token_id: token for token, token_id in vocab.items()}

    model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
    model.eval()  # inference mode: disables dropout etc.

    t = time.time()
    img = Image.open(args.test_img).convert('RGB')
    pixel_values = processor([img], return_tensors="pt").pixel_values
    with torch.no_grad():
        # `[:, :, :]` was a no-op full slice; pass the tensor directly.
        generated_ids = model.generate(pixel_values.cpu())
    generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
    print('time take:', round(time.time() - t, 2), "s ocr:", [generated_text])