import os
import re
import json
from collections import OrderedDict

import torch

import valid
from utils import utils
from utils import option
from data import dataset
from model import htr_convtext


def _levenshtein(pred_tokens, gt_tokens):
    """Return the Levenshtein (edit) distance between two token sequences.

    Uses the classic single-row dynamic-programming formulation, O(len(pred)
    * len(gt)) time and O(len(gt)) memory.
    """
    if pred_tokens == gt_tokens:
        return 0
    lp, lg = len(pred_tokens), len(gt_tokens)
    if lp == 0:
        return lg
    if lg == 0:
        return lp
    prev = list(range(lg + 1))
    for i in range(1, lp + 1):
        cur = [i]
        pi = pred_tokens[i - 1]
        for j in range(1, lg + 1):
            gj = gt_tokens[j - 1]
            cost = 0 if pi == gj else 1
            cur.append(min(prev[j] + 1,        # deletion
                           cur[j - 1] + 1,     # insertion
                           prev[j - 1] + cost  # substitution / match
                           ))
        prev = cur
    return prev[-1]


def _levenshtein_str(a: str, b: str) -> int:
    """Character-level edit distance between two strings."""
    return _levenshtein(list(a), list(b))


def _cer(pred: str, gt: str) -> float:
    """Character error rate: edit distance normalized by ground-truth length.

    An empty ground truth yields 0.0 for an empty prediction, 1.0 otherwise.
    """
    if len(gt) == 0:
        return 0.0 if len(pred) == 0 else 1.0
    return _levenshtein_str(pred, gt) / len(gt)


def _wer(pred: str, gt: str) -> float:
    """Word error rate: word-level edit distance normalized by ground-truth
    word count (words split on whitespace). Same empty-reference convention
    as ``_cer``.
    """
    gt_words = gt.split()
    pred_words = pred.split()
    if len(gt_words) == 0:
        return 0.0 if len(pred_words) == 0 else 1.0
    return _levenshtein(pred_words, gt_words) / len(gt_words)


def main():
    """Evaluate an HTR checkpoint on the test split.

    Loads the EMA weights from ``args.resume``, runs ``valid.validation`` on
    the test loader, logs loss/CER/WER, and writes per-sample predictions
    plus aggregate metrics to ``<save_dir>/predictions.json``.

    Relies on the module-level ``args`` namespace populated by the
    ``__main__`` guard.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)

    args.save_dir = os.path.join(args.out_dir, args.exp_name)
    os.makedirs(args.save_dir, exist_ok=True)
    logger = utils.get_logger(args.save_dir)

    # img_size is reversed because the model expects (H, W) while args store (W, H)
    # -- NOTE(review): assumed from the [::-1]; confirm against create_model.
    model = htr_convtext.create_model(
        nb_cls=args.nb_cls, img_size=args.img_size[::-1])

    pth_path = args.resume
    logger.info('loading HWR checkpoint from {}'.format(pth_path))
    ckpt = torch.load(pth_path, map_location='cpu', weights_only=False)

    # Strip the 'module.' prefix that DataParallel/DDP wrapping prepends to
    # parameter names. The pattern is anchored and the dot escaped: the
    # original 'module.' regex (unescaped '.', unanchored search) could
    # mangle any key merely *containing* the substring "module".
    prefix = re.compile(r'^module\.')
    model_dict = OrderedDict(
        (prefix.sub('', k), v) for k, v in ckpt['state_dict_ema'].items())
    model.load_state_dict(model_dict, strict=True)

    # Use the selected device instead of an unconditional .cuda(): the
    # original call crashed on CPU-only hosts despite the fallback above.
    model = model.to(device)

    logger.info('Loading test loader...')
    # The train dataset is loaded only to reuse its alphabet (ralph) so the
    # test split is decoded with the same label mapping.
    train_dataset = dataset.myLoadDS(
        args.train_data_list, args.data_path, args.img_size,
        dataset=args.dataset)
    test_dataset = dataset.myLoadDS(
        args.test_data_list, args.data_path, args.img_size,
        ralph=train_dataset.ralph, dataset=args.dataset)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.val_bs,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers)

    converter = utils.CTCLabelConverter(train_dataset.ralph.values())
    criterion = torch.nn.CTCLoss(
        reduction='none', zero_infinity=True).to(device)

    model.eval()
    with torch.no_grad():
        val_loss, val_cer, val_wer, preds, labels = valid.validation(
            model, criterion, test_loader, converter,
        )
    logger.info(
        f'Test. \nloss : {val_loss:0.3f} \t CER : {val_cer:0.4f} \t WER : {val_wer:0.4f} ')

    # Assemble aggregate metrics plus one record per test sample.
    results = {
        "test_metrics": {
            "loss": float(val_loss),
            "cer": float(val_cer),
            "wer": float(val_wer)
        },
        "predictions": []
    }

    for i, (pred, label) in enumerate(zip(preds, labels)):
        # Guard against preds/labels outnumbering the file list
        # -- NOTE(review): presumably they are index-aligned; verify in valid.validation.
        if i < len(test_dataset.fns):
            img_path = test_dataset.fns[i]
            img_name = os.path.basename(img_path)
        else:
            img_path = None
            img_name = None
        results["predictions"].append({
            "sample_id": i + 1,
            "image_filename": img_name,
            "image_path": img_path,
            "prediction": pred,
            "ground_truth": label,
            "match": pred == label,
            "cer": round(float(_cer(pred, label)), 6),
            "wer": round(float(_wer(pred, label)), 6)
        })

    pred_file = os.path.join(args.save_dir, 'predictions.json')
    with open(pred_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


if __name__ == '__main__':
    args = option.get_args_parser()
    main()