| | import re |
| | from collections import defaultdict |
| |
|
| | import numpy as np |
| | from tqdm.auto import tqdm |
| |
|
| | from rdkit import Chem, RDLogger, DataStructs |
| | from rdkit.Chem import MACCSkeys, AllChem |
| | from rdkit.Chem.AllChem import AssignStereochemistry |
| | from rdchiral.chiral import copy_chirality |
| |
|
| |
|
| | from transformers import BertTokenizerFast |
| |
|
| | from nltk.translate.bleu_score import corpus_bleu |
| | from nltk.translate.meteor_score import meteor_score |
| | from rouge_score import rouge_scorer |
| | from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, matthews_corrcoef |
| |
|
# Import-time side effect: disable every RDKit log channel ('rdApp.*').
# Presumably this keeps the parse/sanitization failures expected when evaluating
# model-generated SMILES below from flooding stderr — re-enable while debugging.
RDLogger.DisableLog('rdApp.*')
| |
|
| |
|
def canonicalize(smiles, isomeric=False, canonical=True, kekulize=False):
    """Return a SMILES for ``smiles`` written from a cleaned copy of the molecule.

    The input is parsed without sanitization and rebuilt atom by atom. The copy keeps:

    * element and formal charge;
    * explicit H counts, but only on aromatic atoms flagged ``NoImplicit``;
    * bond types, bond directions and bond stereo;
    * atom chirality (via rdchiral's ``copy_chirality``).

    The copy is then partially sanitized (aromaticity or kekulization, properties,
    H adjustment) with ``catchErrors=True``. Stereochemistry is re-assigned and the
    molecule is written out with ``Chem.MolToSmiles``.

    Args:
        smiles: Input SMILES; spaces are removed before parsing.
        isomeric: Passed to ``Chem.MolToSmiles(isomericSmiles=...)``.
        canonical: Passed to ``Chem.MolToSmiles(canonical=...)``.
        kekulize: If true, sanitize with ``SANITIZE_KEKULIZE`` instead of
            ``SANITIZE_SETAROMATICITY``.

    Returns:
        The SMILES string produced by ``Chem.MolToSmiles``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` at all.
    """

    def copy_atom(atom):
        # Keep only element and formal charge; explicit H counts are carried over
        # only for aromatic atoms flagged NoImplicit.
        new_atom = Chem.Atom(atom.GetSymbol())
        new_atom.SetFormalCharge(atom.GetFormalCharge())
        if atom.GetIsAromatic() and atom.GetNoImplicit():
            new_atom.SetNumExplicitHs(atom.GetNumExplicitHs())
        return new_atom

    def copy_edit_mol(mol):
        # Rebuild ``mol`` in a fresh editable molecule with the same atom indices,
        # bond types, bond directions, bond stereo and atom chirality.
        new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
        for atom in mol.GetAtoms():
            new_mol.AddAtom(copy_atom(atom))
        for bond in mol.GetBonds():
            a1 = bond.GetBeginAtom().GetIdx()
            a2 = bond.GetEndAtom().GetIdx()
            new_mol.AddBond(a1, a2, bond.GetBondType())
            new_bond = new_mol.GetBondBetweenAtoms(a1, a2)
            new_bond.SetBondDir(bond.GetBondDir())
            new_bond.SetStereo(bond.GetStereo())
        for new_atom in new_mol.GetAtoms():
            atom = mol.GetAtomWithIdx(new_atom.GetIdx())
            copy_chirality(atom, new_atom)
        return new_mol

    smiles = smiles.replace(" ", "")
    tmp = Chem.MolFromSmiles(smiles, sanitize=False)
    if tmp is None:
        # MolFromSmiles reports a parse failure by returning None; fail with a clear
        # message rather than an AttributeError on the next call.
        raise ValueError('Unable to parse SMILES: %s' % smiles)
    tmp.UpdatePropertyCache()
    new_mol = copy_edit_mol(tmp)

    # Partial sanitization. catchErrors=True makes RDKit report a failing step
    # instead of raising, so a best-effort SMILES is still produced.
    if not kekulize:
        Chem.SanitizeMol(new_mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_SETAROMATICITY | Chem.SanitizeFlags.SANITIZE_PROPERTIES | Chem.SanitizeFlags.SANITIZE_ADJUSTHS, catchErrors=True)
    else:
        Chem.SanitizeMol(new_mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE | Chem.SanitizeFlags.SANITIZE_PROPERTIES | Chem.SanitizeFlags.SANITIZE_ADJUSTHS, catchErrors=True)

    AssignStereochemistry(new_mol, cleanIt=False, force=True, flagPossibleStereoCenters=True)

    return Chem.MolToSmiles(new_mol, isomericSmiles=isomeric, canonical=canonical)
| |
|
| |
|
def canonicalize_molecule_smiles(smiles, return_none_for_error=True, skip_mol=False, sort_things=True, isomeric=True, kekulization=True, allow_empty_part=False):
    """Canonicalize a (possibly multi-component) molecule SMILES.

    The input is split on '.' and each component is processed independently:

    1. Atom-map numbers are cleared.
    2. The SMILES is written and re-parsed with RDKit three times.
    3. The result is passed through :func:`canonicalize`.
    4. Optionally, it is re-written as Kekulé SMILES.

    The components are then optionally sorted and re-joined with '.'.

    Args:
        smiles: SMILES string whose components are separated by '.'.
        return_none_for_error: If true, return None as soon as any component
            fails; otherwise re-raise the error.
        skip_mol: If true, skip all RDKit processing and only split/sort/join.
        sort_things: Sort the components before joining.
        isomeric: Keep stereochemistry in the written SMILES.
        kekulization: Emit Kekulé SMILES for each component.
        allow_empty_part: Accept empty components (e.g. from 'C..C'); otherwise
            an empty component is an error.

    Returns:
        The canonical SMILES, or None on error when ``return_none_for_error``.

    Raises:
        ValueError, AssertionError or RDKit errors: only when
            ``return_none_for_error`` is false.
    """
    things = smiles.split('.')
    if skip_mol:
        new_things = things
    else:
        new_things = []
        for thing in things:
            try:
                if thing == '' and not allow_empty_part:
                    raise ValueError('SMILES contains empty part.')

                mol = Chem.MolFromSmiles(thing)
                assert mol is not None
                for atom in mol.GetAtoms():
                    atom.SetAtomMapNum(0)
                thing_smiles = Chem.MolToSmiles(mol, kekuleSmiles=False, isomericSmiles=isomeric)
                # Two more parse/write round trips. Presumably a single RDKit write is
                # not always a fixed point — TODO confirm this is still needed.
                for _ in range(2):
                    thing_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(thing_smiles), kekuleSmiles=False, isomericSmiles=isomeric)
                can_in = thing_smiles
                can_out = canonicalize(thing_smiles, isomeric=isomeric)
                assert can_out is not None, can_in
                thing_smiles = can_out
                if kekulization:
                    keku_mid = Chem.MolFromSmiles(thing_smiles)
                    assert keku_mid is not None, 'Before can: %s\nAfter can: %s' % (can_in, can_out)
                    thing_smiles = Chem.MolToSmiles(keku_mid, kekuleSmiles=True, isomericSmiles=isomeric)
            except Exception:
                # Best effort: any RDKit/validation failure marks the whole input as
                # invalid. KeyboardInterrupt/SystemExit are not Exceptions and propagate.
                if return_none_for_error:
                    return None
                raise
            new_things.append(thing_smiles)
    if sort_things:
        new_things = sorted(new_things)
    return '.'.join(new_things)
| |
|
| |
|
def canonicalize_reaction_smiles(smiles, return_none_for_error=True, return_segs=False, skip_mol=False, sort_things=True, isomeric=True, kekulization=True):
    """Canonicalize a reaction SMILES of the form ``reactants>reagents>products``.

    Every non-empty segment is passed to :func:`canonicalize_molecule_smiles` with
    the given options; empty segments stay empty.

    Returns:
        The canonical reaction SMILES. If ``return_segs`` is true, the tuple of the
        three canonical segments is returned instead. None is returned if a segment
        fails and ``return_none_for_error`` is true.

    Raises:
        AssertionError: if ``smiles`` does not split into exactly three segments on '>'.
    """
    parts = smiles.split('>')
    assert len(parts) == 3
    canonical_parts = []
    for part in parts:
        if part == '':
            canonical_parts.append('')
            continue
        canonical = canonicalize_molecule_smiles(
            part,
            return_none_for_error=return_none_for_error,
            skip_mol=skip_mol,
            sort_things=sort_things,
            isomeric=isomeric,
            kekulization=kekulization,
        )
        if return_none_for_error and canonical is None:
            return None
        canonical_parts.append(canonical)

    if return_segs:
        return tuple(canonical_parts)
    return '>'.join(canonical_parts)
| |
|
| |
|
def get_molecule_id(smiles, remove_duplicate=True):
    """Return an InChI-based identifier for ``smiles``.

    With ``remove_duplicate`` (the default) the SMILES is split on '.' and the id
    is the sorted tuple of the distinct component InChIs. Component order and
    repeated components therefore do not change the id. Otherwise the whole string
    is converted to a single InChI string.

    Raises:
        AssertionError: if ``remove_duplicate`` is true and ``smiles`` contains ';'.
    """
    if not remove_duplicate:
        return Chem.MolToInchi(Chem.MolFromSmiles(smiles))
    assert ';' not in smiles
    unique_ids = {get_molecule_id(component, remove_duplicate=False) for component in smiles.split('.')}
    return tuple(sorted(unique_ids))
| |
|
| |
|
def convert_smiles_list_into_mol_list(smiles_list, raise_error_when_error=False):
    """Parse SMILES strings into RDKit molecules, with per-item status labels.

    Args:
        smiles_list: SMILES strings; '' means "no answer".
        raise_error_when_error: If true, raise ValueError at the first empty or
            unparseable SMILES instead of recording a placeholder.

    Returns:
        ``(mol_list, no_answer_labels, invalid_labels)``:

        * ``mol_list[i]`` is the parsed molecule, 'NA' for an empty SMILES, or
          'INVALID' when RDKit cannot parse it.
        * The two labels are boolean numpy arrays aligned with ``mol_list``.

    Raises:
        ValueError: only when ``raise_error_when_error`` is true.
    """
    mol_list = []
    no_answer_labels = []
    invalid_labels = []
    for smiles in smiles_list:
        if smiles == '':
            if raise_error_when_error:
                raise ValueError('SMILES is empty.')
            mol_list.append('NA')
            no_answer_labels.append(True)
            invalid_labels.append(False)
            continue
        mol = Chem.MolFromSmiles(smiles)
        is_invalid = mol is None
        if is_invalid:
            if raise_error_when_error:
                raise ValueError('SMILES is not valid: %s' % smiles)
            mol = 'INVALID'
        mol_list.append(mol)
        # One label per item, so the arrays line up with mol_list.
        no_answer_labels.append(False)
        invalid_labels.append(is_invalid)

    return mol_list, np.array(no_answer_labels, dtype=bool), np.array(invalid_labels, dtype=bool)
| |
|
| |
|
def judge_exact_match(pred_can_smiles_list, gold_can_smiles_list):
    """Flag, per sample, whether the prediction matches any of its gold SMILES.

    Matching compares :func:`get_molecule_id` results (sorted tuples of distinct
    component InChIs). A None prediction never matches.

    Args:
        pred_can_smiles_list: One canonical SMILES (or None) per sample.
        gold_can_smiles_list: One list of gold SMILES per sample; golds must not be None.

    Returns:
        numpy array with one boolean per sample.
    """
    assert len(pred_can_smiles_list) == len(gold_can_smiles_list)

    def _matches(pred_smiles, gold_smiles_list):
        if pred_smiles is None:
            return False
        pred_id = get_molecule_id(pred_smiles)
        for gold_smiles in gold_smiles_list:
            assert gold_smiles is not None
            if pred_id == get_molecule_id(gold_smiles):
                return True
        return False

    return np.array([
        _matches(pred_smiles, gold_smiles_list)
        for pred_smiles, gold_smiles_list in zip(pred_can_smiles_list, gold_can_smiles_list)
    ])
| |
|
| |
|
def calculate_fingerprint_similarity(pred_mol_list, gold_mols_list, morgan_r=2):
    """Average best-match Tanimoto similarity between predictions and gold molecules.

    For each sample, the prediction is compared with every gold molecule using three
    fingerprints: MACCS keys, RDKit, and Morgan (radius ``morgan_r``). For each
    fingerprint type the maximum Tanimoto similarity is kept (0 when the sample has
    no gold molecules). These per-sample maxima are then averaged.

    Args:
        pred_mol_list: One RDKit molecule per sample.
        gold_mols_list: One list of RDKit molecules per sample.
        morgan_r: Morgan fingerprint radius.

    Returns:
        ``(maccs, rdk, morgan)`` mean similarities (``np.mean`` of the per-sample
        maxima; nan for empty input).

    Raises:
        ValueError: if a prediction is None or a string placeholder such as
            'NA'/'INVALID'.
    """
    assert len(pred_mol_list) == len(gold_mols_list)
    maccs_sims = []
    rdk_sims = []
    morgan_sims = []
    for pred_mol, gold_mol_list in zip(pred_mol_list, gold_mols_list):
        if pred_mol is None or isinstance(pred_mol, str):
            raise ValueError(type(pred_mol))
        # The prediction's fingerprints do not depend on the gold molecule, so they
        # are computed once per sample rather than once per gold.
        pred_maccs = MACCSkeys.GenMACCSKeys(pred_mol)
        pred_rdk = Chem.RDKFingerprint(pred_mol)
        pred_morgan = AllChem.GetMorganFingerprint(pred_mol, morgan_r)
        best_maccs, best_rdk, best_morgan = 0, 0, 0
        for gold_mol in gold_mol_list:
            best_maccs = max(best_maccs, DataStructs.FingerprintSimilarity(MACCSkeys.GenMACCSKeys(gold_mol), pred_maccs, metric=DataStructs.TanimotoSimilarity))
            best_rdk = max(best_rdk, DataStructs.FingerprintSimilarity(Chem.RDKFingerprint(gold_mol), pred_rdk, metric=DataStructs.TanimotoSimilarity))
            best_morgan = max(best_morgan, DataStructs.TanimotoSimilarity(AllChem.GetMorganFingerprint(gold_mol, morgan_r), pred_morgan))
        maccs_sims.append(best_maccs)
        rdk_sims.append(best_rdk)
        morgan_sims.append(best_morgan)
    return np.mean(maccs_sims), np.mean(rdk_sims), np.mean(morgan_sims)
| |
|
| |
|
def judge_multiple_match(pred_can_smiles_list, golds_can_smiles_list):
    """Component-level matching of multi-component predictions against gold SMILES.

    Each SMILES is split on '.' and every component is mapped to its InChI. For each
    sample:

    * intersection: the prediction shares at least one component with some gold;
    * subset: every prediction component occurs in some gold (equality counts).

    Golds are scanned in order, and scanning stops at the first gold satisfying the
    subset condition. A None prediction gets False for both labels.

    Returns:
        ``(intersection_labels, subset_labels)`` as two lists of booleans.
    """
    assert len(pred_can_smiles_list) == len(golds_can_smiles_list)

    def _component_ids(smiles):
        return {get_molecule_id(part, remove_duplicate=False) for part in smiles.split('.')}

    intersection_labels, subset_labels = [], []
    for pred_smiles, gold_smiles_list in zip(pred_can_smiles_list, golds_can_smiles_list):
        intersects = is_subset = False
        if pred_smiles is not None:
            pred_ids = _component_ids(pred_smiles)
            for gold_smiles in gold_smiles_list:
                assert gold_smiles is not None
                gold_ids = _component_ids(gold_smiles)
                if pred_ids & gold_ids:
                    intersects = True
                if pred_ids <= gold_ids:
                    is_subset = True
                    break
        intersection_labels.append(intersects)
        subset_labels.append(is_subset)
    return intersection_labels, subset_labels
| |
|
| |
|
def calculate_smiles_metrics(
        preds_smiles_list,
        golds_smiles_list,
        metrics=('exact_match', 'fingerprint')
):
    """Compute top-k molecule metrics for SMILES predictions.

    Args:
        preds_smiles_list: One entry per sample. Each entry is a list of k ranked
            predicted SMILES ('' or None = no answer at that rank), or None when the
            sample has no predictions at all.
        golds_smiles_list: One list of gold SMILES per sample. Golds must be
            canonicalizable; canonicalization errors are raised.
        metrics: Any of 'exact_match', 'fingerprint' (top-1 only) and 'multiple_match'.

    Returns:
        Dict with 'num_all' and, for each t in 1..k:

        * 'num_t{t}_no_answer' and 'num_t{t}_invalid': samples whose first t
          predictions are all missing, or all invalid, respectively.
        * 'num_t{t}_exact_match' ('exact_match').
        * 'num_t{t}_subset' and 'num_t{t}_intersection' ('multiple_match').

        'fingerprint' adds 't1_maccs_fps', 't1_rdk_fps' and 't1_morgan_fps'.

    Raises:
        ValueError: for an unknown metric name.
    """
    num_all = len(preds_smiles_list)
    assert num_all > 0
    assert num_all == len(golds_smiles_list)
    # Predictions per sample, taken from the first sample that has any (samples with
    # no predictions are None, so len() on them would fail). All-None -> top-1.
    k = next((len(p) for p in preds_smiles_list if p is not None), 1)

    # Per rank dk: canonical prediction (None when missing or invalid) and flags.
    dk_pred_smiles_list_dict = {dk: [] for dk in range(k)}
    dk_pred_no_answer_labels_dict = {dk: [] for dk in range(k)}
    dk_pred_invalid_labels_dict = {dk: [] for dk in range(k)}
    for pred_smiles_list in tqdm(preds_smiles_list):
        if pred_smiles_list is None:
            for dk in range(k):
                dk_pred_no_answer_labels_dict[dk].append(True)
                dk_pred_invalid_labels_dict[dk].append(False)
                dk_pred_smiles_list_dict[dk].append(None)
            continue
        assert len(pred_smiles_list) == k
        for dk, item in enumerate(pred_smiles_list):
            if item == '' or item is None:
                item = None
                dk_pred_no_answer_labels_dict[dk].append(True)
                dk_pred_invalid_labels_dict[dk].append(False)
            else:
                dk_pred_no_answer_labels_dict[dk].append(False)
                # canonicalize_molecule_smiles returns None for unusable SMILES.
                item = canonicalize_molecule_smiles(item)
                dk_pred_invalid_labels_dict[dk].append(item is None)
            dk_pred_smiles_list_dict[dk].append(item)

    # Golds must be valid: errors are raised rather than masked as None.
    golds_smiles_list = [
        [canonicalize_molecule_smiles(gold.strip(), return_none_for_error=False) for gold in gold_smiles_list]
        for gold_smiles_list in tqdm(golds_smiles_list)
    ]

    metric_results = {'num_all': num_all}

    # Cumulative over ranks: at t, a sample counts only if all of its first t
    # predictions are missing (resp. invalid).
    tk_pred_no_answer_labels = np.array([True] * num_all)
    tk_pred_invalid_labels = np.array([True] * num_all)
    for dk in range(k):
        tk_pred_no_answer_labels = tk_pred_no_answer_labels & np.array(dk_pred_no_answer_labels_dict[dk], dtype=bool)
        tk_pred_invalid_labels = tk_pred_invalid_labels & np.array(dk_pred_invalid_labels_dict[dk], dtype=bool)
        metric_results['num_t%d_no_answer' % (dk + 1)] = tk_pred_no_answer_labels.sum().item()
        metric_results['num_t%d_invalid' % (dk + 1)] = tk_pred_invalid_labels.sum().item()

    for metric in metrics:
        if metric == 'exact_match':
            # Cumulative: a top-t hit if any of the first t predictions matches.
            tk_exact_match_labels = np.array([False] * num_all)
            for dk in range(k):
                dk_exact_match_labels = judge_exact_match(dk_pred_smiles_list_dict[dk], golds_smiles_list)
                tk_exact_match_labels = tk_exact_match_labels | dk_exact_match_labels
                metric_results['num_t%d_exact_match' % (dk + 1)] = tk_exact_match_labels.sum().item()
        elif metric == 'fingerprint':
            # Top-1 only; samples whose first prediction is missing/invalid are skipped.
            d1_pred_mol_list = []
            gold_mols_list = []
            for pred_smiles, gold_smiles_list, no_answer, invalid in zip(dk_pred_smiles_list_dict[0], golds_smiles_list, dk_pred_no_answer_labels_dict[0], dk_pred_invalid_labels_dict[0]):
                if pred_smiles is None or pred_smiles.strip() == '' or no_answer or invalid:
                    continue
                pred_mol = Chem.MolFromSmiles(pred_smiles)
                assert pred_mol is not None, pred_smiles
                gold_mol_list = []
                for gold_smiles in gold_smiles_list:
                    gold_mol = Chem.MolFromSmiles(gold_smiles)
                    assert gold_mol is not None, gold_smiles
                    gold_mol_list.append(gold_mol)
                d1_pred_mol_list.append(pred_mol)
                gold_mols_list.append(gold_mol_list)
            maccs_sims_score, rdk_sims_score, morgan_sims_score = calculate_fingerprint_similarity(d1_pred_mol_list, gold_mols_list)
            metric_results['t1_maccs_fps'] = maccs_sims_score
            metric_results['t1_rdk_fps'] = rdk_sims_score
            metric_results['t1_morgan_fps'] = morgan_sims_score
        elif metric == 'multiple_match':
            tk_intersection_labels = np.array([False] * num_all)
            tk_subset_labels = np.array([False] * num_all)
            for dk in range(k):
                dk_intersection_labels, dk_subset_labels = judge_multiple_match(dk_pred_smiles_list_dict[dk], golds_smiles_list)
                tk_intersection_labels = tk_intersection_labels | dk_intersection_labels
                tk_subset_labels = tk_subset_labels | dk_subset_labels
                # Each count comes from its own label array (subset previously reused
                # the intersection labels).
                metric_results['num_t%d_subset' % (dk + 1)] = tk_subset_labels.sum().item()
                metric_results['num_t%d_intersection' % (dk + 1)] = tk_intersection_labels.sum().item()
        else:
            raise ValueError(metric)

    return metric_results
| |
|
| |
|
def judge_string_exact_match(pred_string_list, golds_string_list):
    """Flag, per sample, whether the prediction string equals any of its gold strings.

    Comparison is plain ``==`` on the raw values; no normalization is applied.

    Returns:
        numpy array with one boolean per sample.
    """
    return np.array([
        any(pred_string == gold_string for gold_string in gold_string_list)
        for pred_string, gold_string_list in zip(pred_string_list, golds_string_list)
    ])
| |
|
| |
|
def judge_string_split_match(pred_string_list, golds_string_list, separater=';'):
    """Flag, per sample, whether the prediction equals a gold as an unordered list of parts.

    Strings are split on ``separater`` and the parts compared as sorted tuples, so
    ``'a;b'`` matches ``'b;a'``; duplicate parts still count. A None prediction
    (no answer) never matches.

    Returns:
        numpy array with one boolean per sample.
    """
    def _key(text):
        return tuple(sorted(text.split(separater)))

    exact_match_labels = []
    for pred_string, gold_string_list in zip(pred_string_list, golds_string_list):
        if pred_string is None:
            # A missing answer cannot be split; it is simply not a match.
            exact_match_labels.append(False)
            continue
        pred_item = _key(pred_string)
        exact_match_labels.append(any(pred_item == _key(gold_string) for gold_string in gold_string_list))
    return np.array(exact_match_labels)
| |
|
| |
|
def parse_molecule(molecular_formula):
    """Parse a molecular formula into a dict of element counts (plus charge).

    Validation: the whole string must be one or more ``letter[digits]`` groups,
    followed by optional ``+``/``-`` groups with optional digits. Anything else,
    including brackets and the empty string, raises ``ValueError``.

    Element symbols are an uppercase letter optionally followed by one lowercase
    letter. A missing count means 1, and repeated symbols are summed, so ``'H8C5'``
    and ``'C5H8'`` give the same dict. Element parsing stops at the first character
    that does not start an element symbol or a bracket (normally the charge suffix).

    The first charge group found is stored under the key ``'+'`` or ``'-'`` with its
    magnitude (1 when it has no digits).

    Args:
        molecular_formula: Formula string, e.g. ``'C6H12O6'`` or ``'SO4-2'``.

    Returns:
        dict mapping element symbols (and possibly ``'+'``/``'-'``) to ints.

    Raises:
        ValueError: if the formula fails validation.
    """
    # Raw strings: in a plain literal, \d, \+ and \- are invalid escape sequences
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    valid = re.match(r'([A-Za-z]\d*)+([\+\-]\d*)*$', molecular_formula)
    if valid is None:
        raise ValueError("Molecular formula \"%s\" is not valid." % molecular_formula)

    # Iterative token scan so that long formulas cannot hit the recursion limit.
    # One counter per open bracket level; stack[0] is the top level.
    stack = [defaultdict(int)]
    rest = molecular_formula
    while rest:
        atom = re.match(r'([A-Z][a-z]?)(\d+)?', rest)
        opening = re.match(r'[\(\[\{]', rest)
        closing = re.match(r'[\)\]\}](\d+)?', rest)
        if atom:
            rest = rest[atom.end():]
            stack[-1][atom.group(1)] += int(atom.group(2) or 1)
        elif opening:
            rest = rest[opening.end():]
            stack.append(defaultdict(int))
        elif closing:
            rest = rest[closing.end():]
            multiplier = int(closing.group(1) or 1)
            for element, count in stack.pop().items():
                stack[-1][element] += count * multiplier
        else:
            # Not an element symbol or bracket (e.g. the charge suffix): stop here.
            break
    result = dict(stack[0])

    charge = re.search(r'[\+\-]\d*', molecular_formula)
    if charge is not None:
        charge_str = charge.group()
        charge_type = charge_str[0]
        charge_num = int(charge_str[1:]) if len(charge_str) > 1 else 1
        result[charge_type] = charge_num

    return result
| |
|
| |
|
def count_element_match(pred_formula_list, golds_formula_list):
    """Compare predicted and gold formulas by element counts (and charge), per sample.

    Formulas are parsed with :func:`parse_molecule`, so element order does not matter
    (C5H8 == H8C5). Predictions are classified as follows:

    * empty or None: neither matched nor invalid;
    * fails to parse: invalid and unmatched;
    * otherwise: matched if equal to the parsed form of any gold.

    Gold formulas must parse; their errors propagate.

    Returns:
        ``(ele_match_labels, ele_invalid_labels)``: two lists of booleans aligned
        with the inputs.
    """
    assert len(pred_formula_list) == len(golds_formula_list)
    ele_match_labels = []
    ele_invalid_labels = []
    for pred_formula, gold_formula_list in zip(pred_formula_list, golds_formula_list):
        if pred_formula == '' or pred_formula is None:
            ele_invalid_labels.append(False)
            ele_match_labels.append(False)
            continue
        try:
            pred_ele = parse_molecule(pred_formula)
        except Exception:
            # Best effort: any parse failure marks this prediction invalid instead of
            # aborting the evaluation. KeyboardInterrupt/SystemExit still propagate.
            ele_invalid_labels.append(True)
            ele_match_labels.append(False)
            continue
        ele_invalid_labels.append(False)
        # any() stops at the first matching gold, like the previous explicit break.
        ele_match_labels.append(any(pred_ele == parse_molecule(gold_formula) for gold_formula in gold_formula_list))
    return ele_match_labels, ele_invalid_labels
| |
|
| |
|
def calculate_formula_metrics(
        preds_formula_list,
        golds_formula_list,
        metrics=('element_match',)
):
    """
    Calculate metrics for molecular formulas.

    By default this uses element_match (equivalent to the exact_match used in our
    paper), which compares atom counts and ignores element order. For example,
    C5H8 == H8C5.

    Args:
        preds_formula_list: One entry per sample. Each entry is a list of k ranked
            predicted formulas ('' or None = no answer), or None when the sample has
            no predictions at all.
        golds_formula_list: One list of gold formulas per sample; golds are stripped
            and must be non-empty.
        metrics: Any of 'exact_match', 'element_match' and 'split_match'.

    Returns:
        Dict with 'num_all', cumulative 'num_t{t}_no_answer' counts, and the
        per-metric cumulative top-t counts.

    Raises:
        IndexError: if ``preds_formula_list`` is empty.
        ValueError: for an unknown metric name.
    """
    num_all = len(preds_formula_list)
    assert len(preds_formula_list) == len(golds_formula_list)
    if num_all == 0:
        raise IndexError('preds_formula_list is empty.')
    # Predictions per sample, taken from the first sample that has any (samples with
    # no predictions are None, so len() on them would fail). All-None -> top-1.
    k = next((len(item) for item in preds_formula_list if item is not None), 1)

    dk_pred_formula_list_dict = {dk: [] for dk in range(k)}
    for sample_formula_list in preds_formula_list:
        if sample_formula_list is None:
            # No predictions at all: record "no answer" at every rank.
            sample_formula_list = [''] * k
        assert len(sample_formula_list) == k
        for dk, item in enumerate(sample_formula_list):
            dk_pred_formula_list_dict[dk].append(item)

    golds_formula_list = [[small_item.strip() for small_item in item] for item in golds_formula_list]
    assert all(small_item != '' for item in golds_formula_list for small_item in item)

    metric_results = {'num_all': num_all}

    # Cumulative over ranks: at t, a sample counts only if all of its first t
    # predictions are missing.
    tk_no_answer_labels = np.array([True] * num_all)
    for dk in range(k):
        dk_no_answer_labels = np.array([item == '' or item is None for item in dk_pred_formula_list_dict[dk]], dtype=bool)
        tk_no_answer_labels = tk_no_answer_labels & dk_no_answer_labels
        metric_results['num_t%d_no_answer' % (dk + 1)] = tk_no_answer_labels.sum().item()

    for metric in metrics:
        if metric == 'exact_match':
            tk_exact_match_labels = np.array([False] * num_all)
            for dk in range(k):
                dk_exact_match_labels = judge_string_exact_match(dk_pred_formula_list_dict[dk], golds_formula_list)
                tk_exact_match_labels = tk_exact_match_labels | dk_exact_match_labels
                metric_results['num_t%d_exact_match' % (dk + 1)] = tk_exact_match_labels.sum().item()
        elif metric == 'element_match':
            tk_ele_match_labels = np.array([False] * num_all)
            tk_formula_invalid_labels = np.array([True] * num_all)
            for dk in range(k):
                dk_ele_match_labels, dk_formula_invalid_labels = count_element_match(dk_pred_formula_list_dict[dk], golds_formula_list)
                tk_ele_match_labels = tk_ele_match_labels | dk_ele_match_labels
                tk_formula_invalid_labels = tk_formula_invalid_labels & dk_formula_invalid_labels
                metric_results['num_t%d_ele_match' % (dk + 1)] = tk_ele_match_labels.sum().item()
                metric_results['num_t%d_formula_invalid' % (dk + 1)] = tk_formula_invalid_labels.sum().item()
        elif metric == 'split_match':
            tk_exact_match_labels = np.array([False] * num_all)
            for dk in range(k):
                dk_exact_match_labels = judge_string_split_match(dk_pred_formula_list_dict[dk], golds_formula_list)
                tk_exact_match_labels = tk_exact_match_labels | dk_exact_match_labels
                metric_results['num_t%d_split_match' % (dk + 1)] = tk_exact_match_labels.sum().item()
        else:
            raise ValueError(metric)

    return metric_results
| |
|
| |
|
def calculate_text_metrics(pred_text_list, gold_text_list, text_model='/AIRvePFS/dair/fsq-data/experiments/biomedgpt/biomedgpt_qwen/ckpts/text_ckpts/scibert_scivocab_uncased', text_trunc_length=512):
    """Compute BLEU-2/4, ROUGE-1/2/L F-measure and METEOR for generated text.

    Only the first element of each prediction/gold entry is used, after ``strip``.
    None or empty predictions are counted in ``num_no_answer`` and excluded from
    every score.

    BLEU and METEOR are computed on word pieces from the ``BertTokenizerFast`` loaded
    from ``text_model``, truncated to ``text_trunc_length`` with [PAD]/[CLS]/[SEP]
    removed. ROUGE is computed by ``rouge_score`` on the raw strings.

    NOTE(review): the default ``text_model`` is an absolute, machine-specific path;
    on other machines pass a local tokenizer directory or hub id explicitly.

    Returns:
        Dict with 'num_all', 'num_no_answer', 'bleu2', 'bleu4', 'rouge_1',
        'rouge_2', 'rouge_l' and 'meteor_score'. The ROUGE and METEOR values are
        ``np.mean`` over answered samples (nan if there are none).
    """
    assert len(pred_text_list) == len(gold_text_list)
    pred_text_list = [(item[0].strip() if item is not None else '') for item in pred_text_list]
    gold_text_list = [item[0].strip() for item in gold_text_list]

    num_no_answer = sum(1 for pred_text in pred_text_list if pred_text == '')

    text_tokenizer = BertTokenizerFast.from_pretrained(text_model)
    special_tokens = ('[PAD]', '[CLS]', '[SEP]')

    def _word_pieces(text):
        # padding='max_length' is requested, so the token list may contain [PAD]
        # (and possibly [CLS]/[SEP]); those are dropped before scoring.
        tokens = text_tokenizer.tokenize(text, truncation=True, max_length=text_trunc_length,
                                         padding='max_length')
        return [token for token in tokens if token not in special_tokens]

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])

    references = []
    hypotheses = []
    meteor_scores = []
    rouge_scores = []
    for gt, out in zip(gold_text_list, pred_text_list):
        if out == '':
            continue

        gt_tokens = _word_pieces(gt)
        out_tokens = _word_pieces(out)

        references.append([gt_tokens])
        hypotheses.append(out_tokens)
        meteor_scores.append(meteor_score([gt_tokens], out_tokens))
        # RougeScorer.score takes (target, prediction).
        rouge_scores.append(scorer.score(gt, out))

    bleu2 = corpus_bleu(references, hypotheses, weights=(.5, .5))
    bleu4 = corpus_bleu(references, hypotheses, weights=(.25, .25, .25, .25))
    _meteor_score = np.mean(meteor_scores)

    rouge_1 = np.mean([rs['rouge1'].fmeasure for rs in rouge_scores])
    rouge_2 = np.mean([rs['rouge2'].fmeasure for rs in rouge_scores])
    rouge_l = np.mean([rs['rougeL'].fmeasure for rs in rouge_scores])

    return {
        'num_all': len(pred_text_list),
        'num_no_answer': num_no_answer,
        'bleu2': bleu2,
        'bleu4': bleu4,
        'rouge_1': rouge_1,
        'rouge_2': rouge_2,
        'rouge_l': rouge_l,
        'meteor_score': _meteor_score,
    }
| |
|
| |
|
def calculate_number_metrics(pred_text_list, gold_text_list):
    """Compute the RMSE between numeric predictions and gold values.

    Each entry is a one-element list holding the value as a string. Predictions are
    handled as follows, and only usable ones enter the RMSE:

    * None or '': counted as no answer;
    * not parseable by ``float()``: counted as invalid.

    Returns:
        Dict with 'num_all', 'num_no_answer', 'num_invalid' and 'RMSE' (nan when no
        prediction is usable).

    Raises:
        ValueError: if a gold value paired with a usable prediction is not a number.
    """
    assert len(pred_text_list) == len(gold_text_list)
    num_all = len(pred_text_list)
    metrics = {}
    metrics['num_all'] = num_all
    num_no_answer = 0
    num_invalid = 0
    new_pred_text_list, new_gold_text_list = [], []
    for (pred_item, gold_item) in zip(pred_text_list, gold_text_list):
        if pred_item is None:
            num_no_answer += 1
            continue
        assert len(pred_item) == 1
        assert len(gold_item) == 1
        pred_item = pred_item[0]
        gold_item = gold_item[0]
        if pred_item == '':
            num_no_answer += 1
            continue
        try:
            pred_item = float(pred_item)
        except (SyntaxError, ValueError):
            num_invalid += 1
            continue
        try:
            gold_item = float(gold_item)
        except (TypeError, ValueError) as err:
            # A malformed gold value is a data error; fail loudly instead of
            # dropping into a debugger.
            raise ValueError('Gold value is not a number: %r' % (gold_item,)) from err
        new_pred_text_list.append(pred_item)
        new_gold_text_list.append(gold_item)

    new_pred_text_list = np.array(new_pred_text_list)
    new_gold_text_list = np.array(new_gold_text_list)
    score = np.sqrt(((new_pred_text_list - new_gold_text_list) ** 2).mean())

    metrics['num_no_answer'] = num_no_answer
    metrics['num_invalid'] = num_invalid
    metrics['RMSE'] = score

    return metrics
| |
|
| |
|
def calculate_boolean_metrics(pred_text_list, gold_text_list):
    """Score yes/no predictions against yes/no gold answers.

    Each entry is a one-element list holding the answer text, compared
    case-insensitively after stripping. Predictions are classified as follows, and
    only valid ones enter the scores:

    * None or '': no answer;
    * any text other than 'yes'/'no': invalid.

    Gold answers other than 'yes' are treated as 'no'.

    Returns:
        Dict with 'num_all', 'num_no_answer', 'num_invalid', 'num_correct',
        'roc_auc_score', 'precision', 'recall', 'f1_score' and 'mcc'. 'yes' is
        encoded as 1.
    """
    assert len(pred_text_list) == len(gold_text_list)
    metrics = {'num_all': len(pred_text_list)}
    num_no_answer = num_invalid = num_correct = 0
    y_true, y_pred = [], []
    for raw_pred, raw_gold in zip(pred_text_list, gold_text_list):
        if raw_pred is None or raw_pred == '':
            num_no_answer += 1
            continue
        assert len(raw_pred) == 1
        assert len(raw_gold) == 1
        pred_answer = raw_pred[0].strip().lower()
        gold_answer = raw_gold[0].strip().lower()
        if not pred_answer:
            num_no_answer += 1
            continue
        if pred_answer not in ('yes', 'no'):
            num_invalid += 1
            continue
        pred_label = int(pred_answer == 'yes')
        gold_label = int(gold_answer == 'yes')
        y_pred.append(pred_label)
        y_true.append(gold_label)
        num_correct += int(pred_label == gold_label)

    metrics['num_no_answer'] = num_no_answer
    metrics['num_invalid'] = num_invalid
    metrics['num_correct'] = num_correct

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    roc_auc = roc_auc_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    metrics['roc_auc_score'] = roc_auc
    metrics['precision'] = precision_score(y_true, y_pred)
    metrics['recall'] = recall_score(y_true, y_pred)
    metrics['f1_score'] = f1

    # Labels are re-encoded as -1/1 before MCC.
    metrics['mcc'] = matthews_corrcoef(np.where(y_true == 0, -1, y_true), np.where(y_pred == 0, -1, y_pred))

    return metrics