import os
from PIL import Image
import numpy as np
import time
import torch
import argparse
from glob import glob
from sklearn.model_selection import train_test_split
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from dataset import decode_text
from tqdm import tqdm
from datasets import load_metric

# CER metric implementation is loaded from a local script in the repo.
cer_metric = load_metric("./cer.py")


def compute_metrics(pred_str, label_str):
    """Compute evaluation metrics for a batch of OCR predictions.

    :param pred_str: list of predicted text strings
    :param label_str: list of ground-truth text strings (same length)
    :return: dict with "cer" (character error rate) and "acc"
        (fraction of exact string matches)
    """
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    matches = [pred == label for pred, label in zip(pred_str, label_str)]
    # Explicit empty-input guard instead of the former "+ 0.000001"
    # denominator hack, which also slightly skewed accuracy for
    # non-empty inputs.
    acc = sum(matches) / len(matches) if matches else 0.0
    return {"cer": cer, "acc": acc}


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='trocr 模型评估')
    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="初始化训练权重,用于自己数据集上fine-tune权重")
    parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
    parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
                        help="img path")
    parser.add_argument('--random_state', default=10086, type=int, help="用于训练集划分的随机数")
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES

    paths = glob(args.dataset_path)
    if args.random_state is not None:
        # Reproduce the training-time split so evaluation only sees the
        # held-out 5% of the data.
        train_paths, test_paths = train_test_split(
            paths, test_size=0.05, random_state=args.random_state)
    else:
        train_paths = []
        test_paths = paths
    print("train num:", len(train_paths), "test num:", len(test_paths))

    processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
    model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
    model.eval()

    # Bug fix: the device was hard-coded to "mps", which raises on any
    # machine without Apple-silicon support and contradicted the
    # --CUDA_VISIBLE_DEVICES option. Pick the best available backend.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    model.to(device)

    # token -> id mapping and its inverse, used by decode_text.
    # Built once (the original computed both dicts twice).
    vocab = processor.tokenizer.get_vocab()
    vocab_inp = {token_id: token for token, token_id in vocab.items()}

    pred_str, label_str = [], []
    for p in tqdm(test_paths):
        img = Image.open(p).convert('RGB')
        # Ground truth lives next to the image as <name>.txt; read it as
        # UTF-8 explicitly — labels are Chinese text and the platform
        # default encoding (e.g. on Windows) would mangle them.
        txt_p = os.path.splitext(p)[0] + '.txt'
        with open(txt_p, encoding='utf-8') as f:
            label = f.read().strip()
        pixel_values = processor([img], return_tensors="pt").pixel_values
        with torch.no_grad():
            # (the former "[:, :, :]" slice was a no-op and is dropped)
            generated_ids = model.generate(pixel_values.to(device))
        generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
        pred_str.append(generated_text)
        label_str.append(label)

    res = compute_metrics(pred_str, label_str)
    print(res)