import logging
import os
from argparse import ArgumentParser
from typing import NamedTuple

import numpy as np
import torch as T
from torch.nn import functional as F
from tqdm import tqdm

import diac_utils as du
|
|
# NOTE(review): `_x` is not referenced anywhere else in the visible module.
_x = ["a"]

# Module-level logger keyed by this file's path; records at INFO and above
# pass this logger's own level check.
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
|
|
| def logln(*texts: str): |
| |
| print(*texts) |
|
|
| |
| |
| |
| |
|
|
class PartialDiacMetrics(NamedTuple):
    """Partial-diacritization scores for a ctxt/base system pair.

    ``analysis_summary`` fills every field as a percentage rounded to two
    decimals. "Selected" characters are the ones scored with the ctxt-side
    prediction; the rest are scored with the base prediction.
    """

    # Summed (ctxt correct - base correct) over selected chars, / all chars.
    diff_total: float
    # Selected chars that the base prediction got right, / all chars.
    worse_total: float
    # Same summed difference as diff_total, / number of selected chars.
    diff_relative: float
    # Error rate of the ctxt-side prediction over all chars.
    der_total: float
    # Fraction of chars that are selected.
    selectivity: float
    # Base error rate over unselected chars.
    hidden_der: float
    # Ctxt-side error rate over selected chars.
    partial_der: float
    # Error of the mixed reading: ctxt on selected chars, base elsewhere.
    reader_error: float
|
|
def load_data(path: str):
    """Load a plain-text reference file or a serialized model-output dump.

    Args:
        path: File path, as ``str`` or any ``os.PathLike`` (e.g. ``pathlib.Path``).
            Paths ending in ``.txt`` are read as UTF-8 text. Everything else is
            passed to ``torch.load``.

    Returns:
        For ``.txt`` files, the list of lines with trailing newlines kept.
        Otherwise, whatever ``torch.load`` deserializes.
    """
    # os.fspath accepts both str and PathLike; plain ``path.endswith`` broke on Path.
    if os.fspath(path).endswith('.txt'):
        with open(path, 'r', encoding='utf-8') as fin:
            return fin.readlines()
    # SECURITY: torch.load may unpickle arbitrary objects (depending on the
    # torch version / weights_only default), so only load dumps you trust.
    return T.load(path)
|
|
def parse_data(
    data,
    logits: bool = False,
    side=None,
):
    """Extract diacritic predictions (and gold labels when available) from ``data``.

    Args:
        data: One of:
            * a dict dump with a ``'line_data'`` mapping (and optionally a
              ``'line_data_fix'`` mapping, preferred when present in non-logits mode);
            * a list of text lines, converted to diacritic ids via ``diac_utils``;
            * a ``torch.Tensor`` / ``np.ndarray`` of ids, returned as-is.
        logits: If True, ``data`` must be a dict. Predictions are the argmax of
            ``line_data['diac_logits_<side>']`` and the logits are returned too.
        side: Suffix of the logits key (the CLI passes ``'base'`` / ``'ctxt'``).
            In non-logits dict mode, ``None`` means "use ``'diac_pred'``".

    Returns:
        ``(pred, gt, logits)`` when ``logits`` is True, otherwise ``(pred, gt)``.
        ``gt`` is None for list and tensor/array inputs.

    Raises:
        NotImplementedError: If ``data`` is of an unsupported type.
    """
    if logits:
        line_data = data['line_data']
        diac_logits = T.as_tensor(line_data[f'diac_logits_{side}'])
        diac_pred: T.Tensor = diac_logits.argmax(dim=-1)
        diac_gt: T.Tensor = line_data['diac_gt']
        return diac_pred, diac_gt, diac_logits

    if isinstance(data, dict):
        # Prefer the fixed line data when present. Test membership explicitly:
        # data.get('line_data_fix', data['line_data']) evaluates its default
        # eagerly and raises KeyError for dumps that only have 'line_data_fix'.
        if 'line_data_fix' in data:
            line_data = data['line_data_fix']
        else:
            line_data = data['line_data']
        if side is None:
            diac_pred: T.Tensor = line_data['diac_pred']
        else:
            diac_pred = line_data[f'diac_logits_{side}'].argmax(axis=-1)
        return diac_pred, line_data['diac_gt']

    if isinstance(data, list):
        # Text lines -> per-line diacritic ids, right-padded with DIAC_PAD_IDX.
        data_indices = [
            du.diac_ids_of_line(du.strip_tatweel(du.normalize_spaces(line)))
            for line in data
        ]
        max_len = max(map(len, data_indices), default=0)
        out = np.full((len(data), max_len), fill_value=du.DIAC_PAD_IDX)
        for i_line, line_indices in enumerate(data_indices):
            out[i_line, :len(line_indices)] = line_indices
        return out, None

    if isinstance(data, (T.Tensor, np.ndarray)):
        return data, None

    raise NotImplementedError(f"Unsupported data type: {type(data).__name__}")
|
|
def make_mask_hard(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
):
    """Flag positions where the ctxt and base hard predictions disagree.

    Returns the elementwise result of ``pred_c != pred_m``: True where the two
    predictions differ, False where they agree.
    """
    return pred_c != pred_m
|
|
def make_mask_logits(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    threshold: float = 0.1,
    version: str = '2',
) -> T.BoolTensor:
    """Select positions where the ctxt system is preferred over the base system.

    ``pred_c`` / ``pred_m`` are per-position scores with the class axis last.
    Softmax is applied to both before comparing, so they are presumably logits.
    The returned mask has shape ``pred_c.shape[:-1]`` (one flag per position).
    It therefore lines up with argmax predictions and with a padding mask of
    the same leading shape, even when a batch or sequence dim has size 1.

    Versions (p_c / p_m = softmax probabilities, conf = max probability):
        'hard'     argmax p_c != argmax p_m
        '0'        conf_c > conf_m and conf_m > threshold
        '1'        conf_c - conf_m > threshold
        '1.1'      |conf_c - conf_m| > threshold
        '2'        (p_c - p_m) at argmax p_c > threshold
        '2.1'      (p_c - p_m) at argmax p_m > threshold
        '2.abs'    |(p_c - p_m) at argmax p_c| > threshold
        '2.1.abs'  |(p_c - p_m) at argmax p_m| > threshold
        '3'        max over classes of (p_c - p_m) > threshold
        '4'        'hard' AND '2'

    Raises:
        ValueError: If ``version`` is not one of the above. Previously an
            unknown version surfaced as UnboundLocalError.
    """
    logger.warning("version=%r, threshold=%r", version, threshold)
    prob_c = T.softmax(T.as_tensor(pred_c), dim=-1)
    prob_m = T.softmax(T.as_tensor(pred_m), dim=-1)
    delta = prob_c - prob_m

    def delta_at_argmax(probs: T.Tensor) -> T.Tensor:
        # Probability gap read at the class that `probs` ranks highest.
        # squeeze(-1) drops only the gather axis; a bare .squeeze() would also
        # collapse size-1 batch/sequence dims and break mask alignment.
        index = probs.argmax(dim=-1, keepdim=True)
        return T.gather(delta, dim=-1, index=index).squeeze(-1)

    conf_c = prob_c.max(dim=-1).values
    conf_m = prob_m.max(dim=-1).values

    if version == 'hard':
        return prob_c.argmax(-1) != prob_m.argmax(-1)
    if version == '0':
        return (conf_c > conf_m) & (conf_m > threshold)
    if version == '1':
        return (conf_c - conf_m) > threshold
    if version == '1.1':
        return (conf_c - conf_m).abs() > threshold
    if version == '2':
        return delta_at_argmax(prob_c) > threshold
    if version == '2.1':
        return delta_at_argmax(prob_m) > threshold
    if version == '2.abs':
        return delta_at_argmax(prob_c).abs() > threshold
    if version == '2.1.abs':
        return delta_at_argmax(prob_m).abs() > threshold
    if version == '3':
        return delta.max(dim=-1).values > threshold
    if version == '4':
        disagree = prob_c.argmax(-1) != prob_m.argmax(-1)
        return disagree & (delta_at_argmax(prob_c) > threshold)
    raise ValueError(f"Unknown make_mask_logits version: {version!r}")
|
|
def analysis_summary(
    pred_c: T.LongTensor,
    pred_m: T.LongTensor,
    labels: T.LongTensor,
    padding_mask: T.BoolTensor,
    *,
    selection: T.Tensor = None,
    random: bool = False,
    logits: tuple = None,
):
    """Score a mixed reading: ctxt predictions on selected chars, base elsewhere.

    Padding positions (``padding_mask`` True) are removed before anything is
    counted. All inputs must share the same leading shape.

    Args:
        pred_c: Ctxt-side predicted ids. Replaced when ``logits`` is given.
        pred_m: Base predicted ids.
        labels: Gold ids.
        padding_mask: True at padding positions.
        selection: True at positions scored with the ctxt-side prediction.
            Required.
        random: If True, replace ``selection`` with a Bernoulli sample at the
            same marking rate (a random baseline).
        logits: Optional ``(ctxt_logits, base_logits)``. When given,
            ``pred_c`` becomes the argmax of the summed softmax distributions.

    Returns:
        PartialDiacMetrics with every field as a percentage rounded to 2 decimals.
        ``hidden_der`` is NaN when no character is left unselected.

    Raises:
        ValueError: If ``selection`` is None, there are no non-padding
            characters, or the (possibly random) selection is empty.
    """
    if selection is None:
        raise ValueError("analysis_summary() requires a `selection` mask")

    nonpad_mask = ~T.as_tensor(padding_mask).bool()
    num_chars = nonpad_mask.sum()
    if num_chars == 0:
        raise ValueError("no non-padding characters to evaluate")

    if logits is not None:
        logits = tuple(map(T.as_tensor, logits))
        # Ensemble of both systems: argmax of the summed probability distributions.
        pred_c = (T.softmax(logits[0], dim=-1) + T.softmax(logits[1], dim=-1)).argmax(-1)

    pred_c = T.as_tensor(pred_c)[nonpad_mask]
    pred_m = T.as_tensor(pred_m)[nonpad_mask]
    labels = T.as_tensor(labels)[nonpad_mask]

    ctxt_match = (pred_c == labels).float()
    base_match = (pred_m == labels).float()

    # .bool() so that integer 0/1 masks are treated as masks, not as indices.
    selection = T.as_tensor(selection)[nonpad_mask].bool()
    if random:
        selection = pred_c.new_empty(pred_c.shape).bernoulli_(p=selection.float().mean()).to(bool)
    unselected = ~selection

    num_selected = selection.sum()
    if num_selected == 0:
        raise ValueError("selection marks no non-padding characters")
    num_unselected = unselected.sum()

    ctxt_correct_selected = ctxt_match[selection].sum()
    base_correct_unselected = base_match[unselected].sum()

    base_accuracy = base_correct_unselected / num_unselected  # NaN if all selected
    ctxt_accuracy = ctxt_correct_selected / num_selected
    der_total = 1 - ctxt_match.sum() / num_chars

    diff = (ctxt_match - base_match)[selection].sum()
    diff_total = diff / num_chars
    diff_relative = diff / num_selected

    selectivity = num_selected / num_chars
    worse_total = base_match[selection].sum() / num_chars

    hidden_der = 1.0 - base_accuracy
    partial_der = 1.0 - ctxt_accuracy
    # Count the mixed reading's errors directly. Algebraically this equals
    # selectivity * partial_der + (1 - selectivity) * hidden_der, but it stays
    # finite when nothing is unselected (hidden_der is 0/0 there).
    reader_error = (
        (num_selected - ctxt_correct_selected) + (num_unselected - base_correct_unselected)
    ) / num_chars

    def pct(value: T.Tensor) -> float:
        return round(value.item() * 100, 2)

    return PartialDiacMetrics(
        diff_total=pct(diff_total),
        worse_total=pct(worse_total),
        diff_relative=pct(diff_relative),
        der_total=pct(der_total),
        selectivity=pct(selectivity),
        hidden_der=pct(hidden_der),
        partial_der=pct(partial_der),
        reader_error=pct(reader_error),
    )
|
|
|
|
def relative_improvement_soft(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    labels: T.LongTensor,
    padding_mask: T.Tensor,
):
    """Soft (score-based) comparison of ctxt vs. base predictions.

    ``pred_c`` / ``pred_m`` hold per-position class scores with the class axis
    last. They are used as-is (no softmax), so they are presumably
    probabilities. For each non-padding position, the score that each system
    gives the gold label is read off.

    Returns:
        ``(better, worse, selectivity)``, each a 0-dim tensor divided by the
        number of non-padding positions:
        * better: summed (ctxt - base) gold-label score;
        * worse: summed base gold-label score where the two argmaxes differ;
        * selectivity: count of positions where the two argmaxes differ.
    """
    nonpad_mask = ~T.as_tensor(padding_mask).bool()
    num_chars = nonpad_mask.sum()

    pred_c = T.as_tensor(pred_c)[nonpad_mask]  # (num_nonpad, num_classes)
    pred_m = T.as_tensor(pred_m)[nonpad_mask]
    # gather needs an index with the same rank as its input: (K,) -> (K, 1).
    gold = T.as_tensor(labels)[nonpad_mask].long().unsqueeze(-1)

    ctxt_match = T.gather(pred_c, dim=-1, index=gold).squeeze(-1)
    base_match = T.gather(pred_m, dim=-1, index=gold).squeeze(-1)
    selection = pred_c.argmax(-1) != pred_m.argmax(-1)

    better = T.sum(ctxt_match - base_match) / num_chars
    selectivity = selection.sum() / num_chars
    worse = base_match[selection].sum() / num_chars
    return better, worse, selectivity
|
|
def relative_improvement_masked_soft(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    ground_truth: T.LongTensor,
    padding_mask: T.Tensor,
):
    """Placeholder for a masked soft metric; not implemented.

    NOTE(review): an unreachable draft that used to follow the ``raise``
    sketched this metric: the mean softmax probability that ``pred_c`` assigns
    to ``ground_truth``, over non-padding positions where the ctxt and base
    argmaxes differ. The draft used ``argmax(3)`` and ``1 - padding_mask``.
    It was removed because it could never execute.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("relative_improvement_masked_soft is not implemented")
|
|
def coverage_confidence(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    padding_mask: T.Tensor,
):
    """Placeholder for a coverage/confidence metric; not implemented.

    NOTE(review): an unreachable draft that used to follow the ``raise``
    computed the following: take argmax over axis 3 for both inputs, sum the
    ctxt argmax ids where the two argmaxes differ, and divide by the sum of
    ``1 - padding_mask``. It was removed because it could never execute.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("coverage_confidence is not implemented")
|
|
def cli():
    """Command-line entry point comparing base and ctxt diacritizer dumps.

    Loads the dumps given by ``-m`` (base) and ``-c`` (ctxt), both parsed in
    logits mode. It builds a selection mask (``--mode hard`` or
    ``--mode logits``), then prints the partial-diacritization metrics twice:
    once for the actual selection, and once for a random selection with the
    same marking rate.

    Raises:
        ValueError: If the dumps disagree in length or neither carries ground truth.
    """
    # The text is a description; as the first positional argument it was
    # wrongly used as the program name (`prog`).
    parser = ArgumentParser(
        description='Compare diacritics from base/ctxt systems with partial diac metrics.',
    )
    parser.add_argument('-m', '--model-output-base', required=True,
                        help="Path to tensor.pt dump files of base diacs.")
    parser.add_argument('-c', '--model-output-ctxt', required=True,
                        help="Path to tensor.pt dump files of ctxt diacs.")
    # NOTE(review): --gt is accepted but not read below; ground truth comes from the dumps.
    parser.add_argument('--gt', default=None, help="Path to tensor.pt for gt only.")
    parser.add_argument('--mode', choices=['hard', 'logits'], default='hard')
    args = parser.parse_args()

    # Each is (argmax predictions, ground truth, logits).
    model_output_base = parse_data(
        load_data(args.model_output_base),
        logits=True,
        side='base',
    )
    model_output_ctxt = parse_data(
        load_data(args.model_output_ctxt),
        logits=True,
        side='ctxt',
    )

    logln(f"{model_output_base[0].shape=} , {model_output_ctxt[0].shape=}")

    if len(model_output_base[0]) != len(model_output_ctxt[0]):
        raise ValueError(
            "base and ctxt dumps differ in length: "
            f"{len(model_output_base[0])} vs {len(model_output_ctxt[0])}"
        )

    xc = model_output_ctxt
    xm = model_output_base

    # Prefer the base dump's ground truth, falling back to the ctxt dump's.
    ground_truth = xm[1] if xm[1] is not None else xc[1]
    if ground_truth is None:
        raise ValueError("Neither model output contains ground-truth diacritics.")
    ground_truth = T.as_tensor(ground_truth)
    # Gold id -1 marks padding positions.
    padding_mask = ground_truth == -1

    if args.mode == 'hard':
        selection = make_mask_hard(xc[0], xm[0])
    else:  # 'logits'; argparse `choices` rules out anything else
        selection = make_mask_logits(xc[2], xm[2])

    for title, use_random in (("Actual Totals:", False), ("Random Marked Chars:", True)):
        metrics = analysis_summary(
            xc[0], xm[0], ground_truth, padding_mask,
            random=use_random,
            selection=selection,
            logits=(xc[2], xm[2]),
        )
        logln(title, metrics)
|
|