Spaces:
Paused
Paused
import gradio as gr
import torch
from transformers import BertForQuestionAnswering, BertTokenizer
# Checkpoint whose question-answering head was fine-tuned on SQuAD. The plain
# "bert-base-uncased" checkpoint has no trained qa_outputs weights, so the QA
# head would be randomly initialised and its predicted spans meaningless.
MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"

# get_prediction() uses these three module-level objects; tokenizer and device
# were previously referenced without ever being defined.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference only: make sure dropout is disabled
def get_prediction(context, question):
    """Return the answer span the QA model extracts from *context* for *question*.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. The question
    and context are encoded as a sentence pair. The start and end tokens are
    each chosen by argmax over the model's logits, and the tokens between them
    are decoded back to text. An empty string is returned when the predicted
    end falls before the predicted start.
    """
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        # Long contexts would overflow the model's position limit; trim only
        # the context (second sequence) so the question is always kept whole.
        truncation="only_second",
    ).to(device)

    with torch.no_grad():  # inference: no autograd graph needed
        outputs = model(**inputs)

    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits)) + 1
    if answer_end <= answer_start:
        # Start and end are picked independently and can form no valid span.
        return ""

    answer_ids = inputs["input_ids"][0][answer_start:answer_end]
    # Drop [CLS]/[SEP] if the span happens to touch them.
    return tokenizer.decode(answer_ids, skip_special_tokens=True)
def normalize_text(s):
    """Canonicalize an answer string for SQuAD-style comparison.

    Steps, in order: lower-case, strip every character in
    ``string.punctuation``, replace the whole words "a"/"an"/"the" with a
    space, then collapse all whitespace runs into single spaces.
    """
    import re
    import string

    article_pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    punctuation_table = str.maketrans("", "", string.punctuation)

    cleaned = s.lower()
    cleaned = cleaned.translate(punctuation_table)
    cleaned = article_pattern.sub(" ", cleaned)
    return " ".join(cleaned.split())
def exact_match(prediction, truth):
    """Return True when *prediction* and *truth* are equal after ``normalize_text``."""
    normalized_prediction = normalize_text(prediction)
    normalized_truth = normalize_text(truth)
    return normalized_prediction == normalized_truth
def compute_f1(prediction, truth):
    """Return the token-level F1 overlap between *prediction* and *truth*.

    Both strings are passed through ``normalize_text`` and split on whitespace.
    Shared tokens are counted as a multiset, so a token repeated in both
    strings contributes once per shared occurrence. This matches the official
    SQuAD scorer.

    Returns:
        ``int(pred_tokens == truth_tokens)`` (0 or 1) if either side is empty,
        0 if there is no overlap, otherwise F1 rounded to 2 decimal places.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # If either the prediction or the truth is no-answer, F1 is 1 when they
    # agree (both empty) and 0 otherwise.
    if not pred_tokens or not truth_tokens:
        return int(pred_tokens == truth_tokens)

    # Multiset intersection keeps the minimum count of each shared token.
    # A plain set intersection undercounts repeated tokens and skews
    # precision/recall.
    num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_common == 0:
        return 0

    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)
    return round(2 * (prec * rec) / (prec + rec), 2)
def question_answer(context, question):
    """Answer *question* from *context* by delegating to ``get_prediction``."""
    return get_prediction(context, question)
def greet(texts):
    """Answer a question against a context when both arrive in one string.

    The input is split at its last line break. Everything before that break
    is the context, and the final line is the question. For example, a first
    line "Paris is the capital of France." followed by a second line "What is
    the capital of France?" yields the answer extracted from the first line.

    Args:
        texts: Combined context-and-question string; ``None`` is treated as empty.

    Returns:
        The answer from ``question_answer``, or a usage hint when the input does
        not contain both a non-empty context and a non-empty question.
    """
    context, _, question = (texts or "").strip().rpartition("\n")
    context = context.strip()
    question = question.strip()
    if not context or not question:
        # Nothing to run the model on: explain the expected format rather
        # than raising inside the UI callback.
        return "Please enter the context, then the question on its own last line."
    return question_answer(context, question)
# question_answer(context, question) takes two strings, so the UI provides two
# text inputs, passed in the same order as the function's parameters.
iface = gr.Interface(fn=question_answer, inputs=["text", "text"], outputs="text")

# Start the server only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()