| import os | |
| from PIL import Image | |
| import time | |
| import torch | |
| import argparse | |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
| from dataset import decode_text | |
if __name__ == '__main__':
    # Inference/test entry point: load a fine-tuned TrOCR checkpoint and run
    # OCR on a single image, printing the decoded text and elapsed time.
    parser = argparse.ArgumentParser(description='trocr fine-tune训练')
    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="初始化训练权重,用于自己数据集上fine-tune权重")
    parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
    parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")
    args = parser.parse_args()

    # BUG FIX: the flag was parsed but never applied, so the GPU setting had
    # no effect. '-1' hides all CUDA devices (CPU-only), matching the .cpu()
    # calls below.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.CUDA_VISIBLE_DEVICES

    # Load the processor (tokenizer + image preprocessor) and model weights
    # from the same fine-tuned checkpoint directory.
    processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
    model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
    model.eval()

    # Build the token->id vocab and its id->token inverse once.
    # (The original computed both dicts twice, back to back.)
    vocab = processor.tokenizer.get_vocab()
    vocab_inp = {token_id: token for token, token_id in vocab.items()}

    t = time.time()
    img = Image.open(args.test_img).convert('RGB')
    pixel_values = processor([img], return_tensors="pt").pixel_values
    with torch.no_grad():
        # Removed the no-op full slice `[:, :, :]`; generate on CPU.
        generated_ids = model.generate(pixel_values.cpu())
    # decode_text maps the generated id sequence back to a string using the
    # vocab tables (project-local helper from dataset.py).
    generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
    print('time take:', round(time.time() - t, 2), "s ocr:", [generated_text])