import torch
import torch.utils.data
import torch.backends.cudnn as cudnn
import editdistance

from utils import utils


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    if denominator == 0:
        return float("nan")
    return numerator / float(denominator)


def validation(model, criterion, evaluation_loader, converter):
    """Evaluate ``model`` on ``evaluation_loader`` and report loss, CER and WER.

    For every batch the images are moved to the GPU, run through the model,
    scored with ``criterion`` and greedily decoded (arg-max over the class
    axis) via ``converter.decode``. Character- and word-level edit distances
    are summed over the whole loader and normalised by the total reference
    length, i.e. CER and WER are corpus-level rates rather than per-sample
    averages.

    Args:
        model: Callable mapping a CUDA image batch to a 3-D tensor of scores.
            NOTE(review): the code treats dim 0 as batch and dim 1 as the
            prediction length -- confirm against the model definition.
        criterion: Loss called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``;
            its result is reduced with ``.mean()``. The argument order
            matches ``torch.nn.CTCLoss`` -- presumably that is what is
            passed; verify against the caller.
        evaluation_loader: Iterable yielding ``(image_tensors, labels)``
            batches, where ``labels`` is a sequence of reference strings.
        converter: Object providing ``encode(labels)`` ->
            ``(targets, target_lengths)`` and ``decode(indices, lengths)`` ->
            list of predicted strings.

    Returns:
        Tuple ``(val_loss, CER, WER, all_preds_str, all_labels)``:
        mean per-batch loss, total character edits / total reference
        characters, total word edits / total reference words, and the
        predictions and references in loader order. A ratio whose
        denominator is zero (empty loader, or only empty references) is
        returned as NaN.
    """
    tot_ED = 0
    tot_ED_wer = 0
    length_of_gt = 0
    length_of_gt_wer = 0
    valid_loss = 0.0
    count = 0
    all_preds_str = []
    all_labels = []

    # Evaluation only reads cost.item(), so no autograd graph is needed.
    with torch.no_grad():
        for image_tensors, labels in evaluation_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.cuda()
            text_for_loss, length_for_loss = converter.encode(labels)

            preds = model(image).float()
            # Every sample in the batch gets the same prediction length.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # Swap the first two axes and normalise over the class axis.
            preds = preds.permute(1, 0, 2).log_softmax(2)

            # cuDNN is switched off only for the loss call
            # (NOTE(review): presumably to avoid the cuDNN CTC kernel --
            # confirm). The previous setting is restored even if the
            # criterion raises, so the global flag is never left modified.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(
                    preds, text_for_loss, preds_size, length_for_loss
                ).mean()
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled

            # Greedy decoding: best class per step, flattened sample by sample.
            _, preds_index = preds.max(2)
            preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

            valid_loss += cost.item()
            count += 1
            all_preds_str.extend(preds_str)
            all_labels.extend(labels)

            for pred, gt in zip(preds_str, labels):
                # Character level.
                tot_ED += editdistance.eval(pred, gt)
                length_of_gt += len(gt)

                # Word level: edit distance over single-space-separated tokens.
                # Note that "".split(" ") yields [""], so an empty reference
                # still counts as one (empty) word here.
                pred_words = utils.format_string_for_wer(pred).split(" ")
                gt_words = utils.format_string_for_wer(gt).split(" ")
                tot_ED_wer += editdistance.eval(pred_words, gt_words)
                length_of_gt_wer += len(gt_words)

    val_loss = _safe_ratio(valid_loss, count)
    CER = _safe_ratio(tot_ED, length_of_gt)
    WER = _safe_ratio(tot_ED_wer, length_of_gt_wer)

    return val_loss, CER, WER, all_preds_str, all_labels