File size: 2,821 Bytes
ba0c78a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4db1ad4
b1c66ba
ba0c78a
 
 
 
11d47e4
ba0c78a
 
4db1ad4
ba0c78a
 
 
 
 
 
 
 
 
 
b1c66ba
ba0c78a
 
 
11d47e4
b1c66ba
ba0c78a
11d47e4
ba0c78a
 
 
 
 
 
 
 
 
 
 
 
 
11d47e4
ba0c78a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
from PIL import Image
import numpy as np
import time
import torch
import argparse
from glob import glob
from sklearn.model_selection import train_test_split
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from dataset import decode_text
from tqdm import tqdm
from datasets import load_metric

# Load the CER (character error rate) metric once, at import time, from the
# local metric script ./cer.py; compute_metrics() below calls its .compute().
# NOTE(review): datasets.load_metric is deprecated (removed in datasets 3.x in
# favour of the `evaluate` library) — confirm the pinned datasets version.
cer_metric = load_metric("./cer.py")


def compute_metrics(pred_str, label_str):
    """
    Compute character error rate (CER) and exact-match accuracy.

    :param pred_str: list of predicted strings.
    :param label_str: list of ground-truth strings, index-aligned with ``pred_str``.
    :return: dict with ``"cer"`` (the value returned by ``cer_metric.compute``)
        and ``"acc"`` (fraction of predictions exactly equal to their label;
        0.0 when there are no samples).
    :raises ValueError: if the two lists have different lengths.
    """
    # zip() would silently drop unmatched items and skew both metrics.
    if len(pred_str) != len(label_str):
        raise ValueError(
            f"pred_str and label_str must have the same length "
            f"({len(pred_str)} != {len(label_str)})"
        )
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    matches = sum(pred == label for pred, label in zip(pred_str, label_str))
    # Divide by the true count: the old epsilon in the denominator made a
    # perfect run report 0.999999 instead of 1.0.
    acc = matches / len(label_str) if label_str else 0.0
    return {"cer": cer, "acc": acc}


if __name__ == '__main__':
    # Evaluate a (fine-tuned) TrOCR model on a held-out split of an image
    # dataset and print CER / exact-match accuracy.
    # CLI options (help strings kept verbatim):
    #   --cust_data_init_weights_path: directory holding the weights/processor to evaluate
    #   --CUDA_VISIBLE_DEVICES: exported to the environment below ('-1' hides all GPUs)
    #   --dataset_path: glob pattern selecting the images
    #   --random_state: seed for the train/test split
    parser = argparse.ArgumentParser(description='trocr 模型评估')
    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="初始化训练权重,用于自己数据集上fine-tune权重")
    parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
    parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
                        help="img path")
    parser.add_argument('--random_state', default=10086, type=int, help="用于训练集划分的随机数")

    args = parser.parse_args()
    # Exported before any CUDA call so torch honours it.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES
    paths = glob(args.dataset_path)
    if args.random_state is not None:
        # Re-create a seeded split with 5% held out; presumably the training
        # script uses the same seed so test_paths are unseen images — confirm.
        train_paths, test_paths = train_test_split(paths, test_size=0.05, random_state=args.random_state)

    else:
        # Not reachable from the CLI as written (int argument with a default);
        # kept so callers editing args can evaluate on every image.
        train_paths = []
        test_paths = paths

    print("train num:", len(train_paths), "test num:", len(test_paths))

    processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
    # Token -> id mapping and its inverse, both passed to decode_text().
    vocab = processor.tokenizer.get_vocab()
    vocab_inp = {vocab[key]: key for key in vocab}

    # Prefer Apple MPS (the previous hard-coded target), then CUDA, then CPU,
    # so the script no longer crashes on machines without MPS.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
    model.eval()
    model.to(device)

    pred_str, label_str = [], []
    for p in tqdm(test_paths):
        # Close the underlying file promptly; convert() returns a separate image.
        with Image.open(p) as im:
            img = im.convert('RGB')
        # The label is stored in a sibling .txt file with the same stem.
        txt_p = os.path.splitext(p)[0] + '.txt'
        # Explicit UTF-8 so Chinese labels decode the same on every platform.
        with open(txt_p, encoding='utf-8') as f:
            label = f.read().strip()
        pixel_values = processor([img], return_tensors="pt").pixel_values

        with torch.no_grad():
            generated_ids = model.generate(pixel_values.to(device))

        generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
        pred_str.append(generated_text)
        label_str.append(label)

    res = compute_metrics(pred_str, label_str)
    print(res)