import torch
import os
import re
import json
import valid
from utils import utils
from utils import option
from data import dataset
from model import htr_convtext
from collections import OrderedDict
def _levenshtein(pred_tokens, gt_tokens):
    """Return the Levenshtein edit distance between two token sequences.

    Works on any sequences of comparable items (characters or words).
    Classic two-row dynamic-programming formulation: O(lp * lg) time,
    O(lg) memory.
    """
    if pred_tokens == gt_tokens:
        return 0
    lp, lg = len(pred_tokens), len(gt_tokens)
    if lp == 0:
        return lg
    if lg == 0:
        return lp
    prev = list(range(lg + 1))
    for i in range(1, lp + 1):
        cur = [i]
        pi = pred_tokens[i - 1]
        for j in range(1, lg + 1):
            cost = 0 if pi == gt_tokens[j - 1] else 1
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + cost))
        prev = cur
    return prev[-1]


def _cer(pred: str, gt: str) -> float:
    """Character error rate: char-level edit distance normalized by len(gt).

    Empty ground truth yields 0.0 for an empty prediction, else 1.0.
    """
    if len(gt) == 0:
        return 0.0 if len(pred) == 0 else 1.0
    return _levenshtein(list(pred), list(gt)) / len(gt)


def _wer(pred: str, gt: str) -> float:
    """Word error rate: word-level edit distance normalized by gt word count.

    Empty ground truth yields 0.0 for an empty prediction, else 1.0.
    """
    gt_words = gt.split()
    pred_words = pred.split()
    if len(gt_words) == 0:
        return 0.0 if len(pred_words) == 0 else 1.0
    return _levenshtein(pred_words, gt_words) / len(gt_words)


def main():
    """Evaluate an HTR checkpoint on the test split and dump metrics.

    Reads configuration from the module-level ``args`` (populated in the
    ``__main__`` guard), runs ``valid.validation`` on the test loader, logs
    aggregate loss/CER/WER, and writes per-sample predictions to
    ``<save_dir>/predictions.json``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    args.save_dir = os.path.join(args.out_dir, args.exp_name)
    os.makedirs(args.save_dir, exist_ok=True)
    logger = utils.get_logger(args.save_dir)

    model = htr_convtext.create_model(
        nb_cls=args.nb_cls, img_size=args.img_size[::-1])
    pth_path = args.resume
    logger.info('loading HWR checkpoint from {}'.format(pth_path))
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    ckpt = torch.load(pth_path, map_location='cpu', weights_only=False)

    # Strip a leading 'module.' prefix (left by DataParallel/DDP) from the
    # EMA weights. BUG FIX: the original used re.sub on the pattern
    # 'module.' anywhere in the key, where '.' matches any character —
    # a key containing e.g. 'submodule1' would have been mangled. An
    # explicit prefix check touches only the intended prefix.
    model_dict = OrderedDict()
    for k, v in ckpt['state_dict_ema'].items():
        key = k[len('module.'):] if k.startswith('module.') else k
        model_dict[key] = v
    model.load_state_dict(model_dict, strict=True)
    # BUG FIX: was model.cuda(), which crashes on CPU-only hosts even though
    # `device` already falls back to "cpu". Move to the selected device.
    model = model.to(device)

    logger.info('Loading test loader...')
    # The train split is loaded only to recover the character alphabet
    # (ralph) the model was trained with; the test split reuses it.
    train_dataset = dataset.myLoadDS(
        args.train_data_list, args.data_path, args.img_size, dataset=args.dataset)
    test_dataset = dataset.myLoadDS(
        args.test_data_list, args.data_path, args.img_size,
        ralph=train_dataset.ralph, dataset=args.dataset)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.val_bs,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers)

    converter = utils.CTCLabelConverter(train_dataset.ralph.values())
    criterion = torch.nn.CTCLoss(
        reduction='none', zero_infinity=True).to(device)

    model.eval()
    with torch.no_grad():
        val_loss, val_cer, val_wer, preds, labels = valid.validation(
            model,
            criterion,
            test_loader,
            converter,
        )
    logger.info(
        f'Test. loss : {val_loss:0.3f} \t CER : {val_cer:0.4f} \t WER : {val_wer:0.4f} ')

    # Aggregate metrics plus one record per test sample.
    results = {
        "test_metrics": {
            "loss": float(val_loss),
            "cer": float(val_cer),
            "wer": float(val_wer),
        },
        "predictions": [],
    }
    for i, (pred, label) in enumerate(zip(preds, labels)):
        # assumes test_dataset.fns is aligned with the loader's sample order
        # (shuffle=False) — TODO confirm against dataset.myLoadDS
        if i < len(test_dataset.fns):
            img_path = test_dataset.fns[i]
            img_name = os.path.basename(img_path)
        else:
            img_path = None
            img_name = None
        results["predictions"].append({
            "sample_id": i + 1,
            "image_filename": img_name,
            "image_path": img_path,
            "prediction": pred,
            "ground_truth": label,
            "match": pred == label,
            "cer": round(float(_cer(pred, label)), 6),
            "wer": round(float(_wer(pred, label)), 6),
        })

    pred_file = os.path.join(args.save_dir, 'predictions.json')
    with open(pred_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
if __name__ == '__main__':
    # `args` is intentionally a module-level global: main() reads it
    # directly instead of receiving it as a parameter.
    args = option.get_args_parser()
    main()