Spaces:
Build error
Build error
| # Copyright (c) 2022, Lawrence Livermore National Security, LLC. | |
| # All rights reserved. | |
| # See the top-level LICENSE and NOTICE files for details. | |
| # LLNL-CODE-838964 | |
| # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception | |
| import sys | |
| import json | |
| from math import ceil | |
| import torch | |
| import numpy as np | |
| from torch import tensor | |
| from torch.nn.functional import log_softmax | |
| from torch.distributions.categorical import Categorical | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
# Load the default UnifiedQA checkpoint once, at import time, and move it to
# the GPU when CUDA is available (CPU otherwise).
model_name = "allenai/unifiedqa-v2-t5-large-1363200"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def get_inputs(contexts_json, ranked_contexts_json):
    """Load one question and its ranked context texts from two JSON files.

    Args:
        contexts_json: Path to a JSON object mapping a context id to a record
            that has a ``'text'`` entry.
        ranked_contexts_json: Path to a JSON object mapping a question id to a
            record with ``'text'`` (the question), ``'ranks'`` (context ids)
            and ``'scores'``.

    Returns:
        A ``(question, contexts, context_scores)`` tuple. ``contexts`` is the
        list of context texts in the order given by ``'ranks'``;
        ``context_scores`` is the ``'scores'`` entry, returned unchanged.

    Raises:
        IndexError: If ``ranked_contexts_json`` contains no question.
        KeyError: If an expected key or a ranked context id is missing.
    """
    # JSON is UTF-8 by specification, so read it as such instead of relying
    # on the platform's locale-dependent default encoding.
    with open(contexts_json, 'rt', encoding='utf-8') as fp:
        contexts = json.load(fp)
    with open(ranked_contexts_json, 'rt', encoding='utf-8') as fp:
        ranked_contexts = json.load(fp)

    # Only the first question in the file is used; any others are ignored.
    question_id = list(ranked_contexts)[0]
    entry = ranked_contexts[question_id]
    question = entry['text']
    context_ids_sorted = entry['ranks']
    context_scores = entry['scores']
    contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
    return question, contexts, context_scores
def get_tokens(text, tokenizer, max_tokens):
    """Encode ``text`` and return the ``'input_ids'`` entry of the encoding.

    The tokenizer is asked to truncate and pad to ``max_tokens`` and to
    return PyTorch tensors.
    """
    encoded = tokenizer.encode_plus(
        text,
        max_length=max_tokens,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )
    return encoded['input_ids']
def prepare_inputs(tokenizer, max_tokens, context, question):
    """Build the token ids for one (question, context) pair.

    The question and the context are joined by a separator made of a literal
    backslash followed by ``n`` (not a newline character), then tokenized
    with ``get_tokens``.
    """
    # NOTE(review): presumably this is the separator UnifiedQA was trained
    # with -- confirm against the model card before changing it.
    return get_tokens(f'{question} \\n {context}', tokenizer, max_tokens)
def get_outputs(model, tokenizer, input_tokens, max_tokens):
    """Generate an answer and compute an uncertainty score for it.

    Args:
        model: Object whose ``generate`` returns a mapping with
            ``'sequences'`` and ``'scores'`` when called with
            ``output_scores=True`` and ``return_dict_in_generate=True``.
        tokenizer: Object whose ``decode`` turns token ids into text.
        input_tokens: Encoded input passed straight to ``model.generate``.
        max_tokens: Value forwarded to ``generate`` as ``max_length``.

    Returns:
        A ``(pred_str, uncertainty)`` tuple: the lower-cased decoded answer
        and a float equal to (mean per-step entropy) * exp(sample standard
        deviation of the per-step maximum log-probability). The uncertainty
        is ``0.0`` when fewer than two steps remain after dropping the final
        one.
    """
    output_dict = model.generate(
        input_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        max_length=max_tokens,
    )
    pred_tokens = output_dict['sequences'].squeeze().tolist()
    pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True).lower()

    # Drop the final step, which is taken to be the end-of-sequence token.
    logit_sequence = output_dict['scores'][:-1]

    # A standard deviation needs at least two samples. A single step was
    # already scored as 0; an empty sequence (generation stopped at once)
    # used to yield NaN from the mean/std of empty tensors, so it is given
    # the same 0.0 here.
    if len(logit_sequence) <= 1:
        return pred_str, 0.0

    logit_entropy = []
    sentence_probs = []
    for logit in logit_sequence:
        log_probs = log_softmax(logit, dim=-1)
        logit_entropy.append(Categorical(probs=log_probs.exp()).entropy())
        sentence_probs.append(log_probs.max())

    entropy = tensor(logit_entropy).mean()
    sentence_std = tensor(sentence_probs).std(unbiased=True).exp()

    # entropy * sentence_std is the uncertainty measure
    uncertainty = (entropy * sentence_std).item()
    return pred_str, uncertainty
def run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=0.1, min_k=10, max_k=25, uncertainty_thresh=3, max_tokens=512):
    """Answer ``question`` once per context, for the top-k ranked contexts.

    Args:
        model: Generation model passed on to ``get_outputs``.
        tokenizer: Tokenizer used for both encoding and decoding.
        device: Device the encoded inputs are moved to before generation.
        question: The question text.
        contexts: Context texts, best-ranked first.
        context_scores: Scores sliced to the same length as ``contexts`` and
            returned alongside the answers.
        k_percent: Fraction of the contexts to use; the resulting count is
            raised to at least ``min_k`` and then capped at ``max_k``.
        min_k: Minimum number of contexts to use, if that many are available.
            Setting this too small reduces recall.
        max_k: Maximum number of contexts to use. Setting this too big
            reduces precision.
        uncertainty_thresh: Currently unused. It belonged to an
            uncertainty-based filtering step (recommended values 2 to 5,
            lower meaning more aggressive filtering) that is disabled.
        max_tokens: Token limit used both for encoding each input and as the
            generation ``max_length``.

    Returns:
        A dict with the keys ``'contexts'``, ``'answers'``,
        ``'context_scores'`` and ``'uncertainty'``, each a list with one
        entry per context that was used.
    """
    k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)
    contexts = contexts[:k]
    context_scores = context_scores[:k]

    # Query the model once per selected context.
    answers = []
    uncertainty = []
    for context in contexts:
        input_tokens = prepare_inputs(tokenizer, max_tokens, context, question).to(device)
        pred_str, context_uncertainty = get_outputs(model, tokenizer, input_tokens, max_tokens)
        answers.append(pred_str)
        uncertainty.append(context_uncertainty)

    return {'contexts': contexts, 'answers': answers, 'context_scores': context_scores, 'uncertainty': uncertainty}
def get_qa_results(contexts_json, ranked_contexts_json, topk):
    """Answer the question stored in the given JSON files.

    The question, its ranked contexts and their scores are read with
    ``get_inputs``; the module-level model then answers against at most
    ``topk`` contexts with gradient tracking disabled.
    """
    question, contexts, context_scores = get_inputs(contexts_json, ranked_contexts_json)
    with torch.inference_mode():
        # k_percent=1.0 and min_k=1 leave max_k=topk as the only limit.
        return run_model(
            model,
            tokenizer,
            device,
            question,
            contexts,
            context_scores,
            k_percent=1.0,
            min_k=1,
            max_k=topk,
        )
def get_qa_results_in_memory(contexts, ranked_contexts, topk):
    """Answer the first question of ``ranked_contexts`` from in-memory data.

    ``contexts`` maps a context id to a record with a ``'text'`` entry.
    ``ranked_contexts`` maps a question id to a record with ``'text'``,
    ``'ranks'`` and ``'scores'``. Only the first question is used, and at
    most ``topk`` contexts are passed to the model.
    """
    first_question_id = list(ranked_contexts)[0]
    entry = ranked_contexts[first_question_id]
    question = entry['text']
    ranked_ids = entry['ranks']
    context_scores = entry['scores']
    ranked_texts = [contexts[context_id]['text'] for context_id in ranked_ids]

    with torch.inference_mode():
        return run_model(
            model,
            tokenizer,
            device,
            question,
            ranked_texts,
            context_scores,
            k_percent=1.0,
            min_k=1,
            max_k=topk,
        )
def load_custom_model(finetuned_model_path):
    """Replace the module-level tokenizer and model with a fine-tuned pair.

    Args:
        finetuned_model_path: Location passed to ``from_pretrained`` for both
            the tokenizer and the model.

    Both objects are loaded before either global is rebound, so a failure
    while loading leaves the previously loaded tokenizer and model in place
    instead of a new tokenizer paired with the old model.
    """
    global tokenizer
    global model
    new_tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    new_model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)
    # Swap the pair together, and only then move the new model to the device
    # so the previous model is no longer referenced from here.
    tokenizer, model = new_tokenizer, new_model
    model.to(device)
def get_qa_results_in_memory_finetuned_unifiedqa(question, context_scores, contexts, topk):
    """Answer ``question`` against already-extracted contexts.

    Uses whichever tokenizer and model are currently bound at module level,
    with gradient tracking disabled, and passes at most ``topk`` contexts to
    the model.
    """
    with torch.inference_mode():
        return run_model(
            model,
            tokenizer,
            device,
            question,
            contexts,
            context_scores,
            k_percent=1.0,
            min_k=1,
            max_k=topk,
        )