Spaces:
Runtime error
Runtime error
| import datasets | |
| from llm.qa_agent import QnAAgent | |
# Evaluation data: the reading-comprehension ("rc") config of TriviaQA.
# The gold answers of TriviaQA's public test split are withheld (its answer
# aliases are empty), so exact match can only be measured on the validation
# split. Remove "[:5%]" to run on the full validation set.
validation_dataset = datasets.load_dataset(
    "trivia_qa", "rc", split="validation[:5%]"
)
# Characters that answer normalisation turns into spaces before comparing a
# prediction with the gold aliases. The typographic characters are written as
# Unicode escapes so the set cannot be corrupted by re-encoding this file
# (a mis-decoded copy had turned them into Greek letters).
PUNCTUATION_SET_TO_EXCLUDE = {
    "\u2018",  # left single quotation mark
    "\u2019",  # right single quotation mark
    "\u00b4",  # acute accent
    "`",
    ".",
    ",",
    "-",
    '"',
}
# Single QnAAgent instance, created once when this script runs and used by
# every call to evaluate() below.
# NOTE(review): whatever QnAAgent() sets up (models, clients) is not visible
# here — confirm its construction cost/side effects in llm.qa_agent.
qna_agent = QnAAgent()
def get_sub_answers(answers, begin=0, end=None):
    """Return word-sliced variants of the multi-word answers.

    Each answer is split on single space characters. Answers that yield fewer
    than two pieces are skipped; for the rest, the pieces selected by
    ``[begin:end]`` are re-joined with single spaces.

    Args:
        answers: iterable of answer strings.
        begin: start index of the word slice (default 0).
        end: end index of the word slice (default None, i.e. to the end).

    Returns:
        A list with one sliced string per multi-word answer, in input order.
    """
    sliced = []
    for answer in answers:
        words = answer.split(" ")
        if len(words) < 2:
            # Single-word answers have no meaningful sub-answer.
            continue
        sliced.append(" ".join(words[begin:end]))
    return sliced
def expand_to_aliases(given_answers, make_sub_answers=False):
    """Normalise answer strings into a set of comparable aliases.

    For each string: underscores become spaces, the text is lower-cased,
    every character in PUNCTUATION_SET_TO_EXCLUDE becomes a space, and runs
    of whitespace collapse to single spaces with no leading or trailing
    whitespace.

    Args:
        given_answers: list of answer strings.
        make_sub_answers: when True, every multi-word answer also contributes
            the variant without its first word and the variant without its
            last word. A prediction then still counts as correct when the gold
            answer carries an extra leading word such as "the" or "a" (or an
            extra trailing word).

    Returns:
        A set of normalised alias strings.
    """
    if make_sub_answers:
        without_first_word = get_sub_answers(given_answers, begin=1)
        without_last_word = get_sub_answers(given_answers, end=-1)
        given_answers = given_answers + without_first_word + without_last_word

    # Map every excluded punctuation character to a space in a single pass.
    to_space = str.maketrans(dict.fromkeys(PUNCTUATION_SET_TO_EXCLUDE, " "))

    return {
        " ".join(answer.replace("_", " ").lower().translate(to_space).split())
        for answer in given_answers
    }
def evaluate(example):
    """Query the QnA agent for one example and score it by exact match.

    The agent is asked the question twice, once with ``use_context=False``
    and once with ``use_context=True``. Each answer is normalised with
    expand_to_aliases and compared against the normalised gold aliases. The
    gold aliases are expanded with their first-word/last-word-dropped
    variants.

    Args:
        example: a dataset row; reads ``example["question"]`` and
            ``example["answer"]["aliases"]``.

    Returns:
        The same ``example`` with these keys added:
            "output": the agent's answer without context,
            "output_context": the agent's answer with context,
            "targets": the gold alias list,
            "match": True if "output" matches any gold alias,
            "match_context": True if "output_context" matches any gold alias.
    """
    question = example["question"]
    # Get the answers from the QnA agent, without and then with context.
    example["output"] = qna_agent.get_answer(question, use_context=False)
    example["output_context"] = qna_agent.get_answer(question, use_context=True)
    example["targets"] = example["answer"]["aliases"]

    answers = expand_to_aliases(example["targets"], make_sub_answers=True)
    predictions = expand_to_aliases([example["output"]])
    predictions_with_context = expand_to_aliases([example["output_context"]])

    # It is a match when the normalised prediction shares at least one alias
    # with the gold answers; isdisjoint avoids building the intersection.
    example["match"] = not answers.isdisjoint(predictions)
    example["match_context"] = not answers.isdisjoint(predictions_with_context)
    return example
# Score every example; evaluate() adds "output", "output_context", "targets",
# "match" and "match_context" to each row.
results = validation_dataset.map(evaluate)

# Exact Match (EM) = percentage of examples whose normalised prediction equals
# one of the normalised gold aliases. An empty split would otherwise raise
# ZeroDivisionError (and may lack the "match" columns entirely).
num_examples = len(results)
if num_examples == 0:
    print("No examples to evaluate; Exact Match (EM) is undefined.")
else:
    em_without_context = 100 * sum(results["match"]) / num_examples
    em_with_context = 100 * sum(results["match_context"]) / num_examples
    print(f"Exact Match (EM) without context: {em_without_context:.2f}")
    print(f"Exact Match (EM) with context: {em_with_context:.2f}")