File size: 1,460 Bytes
ba0c78a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4db1ad4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from PIL import Image
import time
import torch
import argparse
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from dataset import decode_text

if __name__ == '__main__':
    # Inference/smoke-test script: load fine-tuned TrOCR weights and run OCR
    # on a single image, printing the decoded text and elapsed time.
    parser = argparse.ArgumentParser(description='trocr fine-tune训练')
    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="初始化训练权重,用于自己数据集上fine-tune权重")
    parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
    parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")

    args = parser.parse_args()

    # BUG FIX: the parsed GPU setting was never applied. Export it before any
    # CUDA context is created so the default '-1' actually forces CPU execution.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.CUDA_VISIBLE_DEVICES

    # Load the processor (image preprocessor + tokenizer) and the
    # encoder-decoder model from the same fine-tuned weights directory.
    processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
    model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
    model.eval()  # disable dropout etc. for deterministic inference

    # Token <-> id maps for decode_text. The original computed both of these
    # twice (before and after model load); once is enough.
    vocab = processor.tokenizer.get_vocab()
    vocab_inp = {token_id: token for token, token_id in vocab.items()}

    t = time.time()
    img = Image.open(args.test_img).convert('RGB')
    pixel_values = processor([img], return_tensors="pt").pixel_values
    with torch.no_grad():
        # Removed the no-op full slice pixel_values[:, :, :]; keep the tensor
        # on CPU to match the model placement (CUDA is masked out above).
        generated_ids = model.generate(pixel_values.cpu())

    # decode_text maps the generated id sequence back to a string using the
    # vocab tables (project-local helper from dataset.py).
    generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
    print('time take:', round(time.time() - t, 2), "s ocr:", [generated_text])