Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, Request
import json
from transformers import BertForQuestionAnswering, BertTokenizer
import torch

app = FastAPI()

# Load the model and tokenizer exactly once at import time; every request
# handler shares these module-level globals.
BERT_CHECKPOINT = 'bert-large-uncased-whole-word-masking-finetuned-squad'

bert_model = BertForQuestionAnswering.from_pretrained(
    BERT_CHECKPOINT,
    # NOTE(review): silently accepts head/config size mismatches — confirm
    # this is intentional for this checkpoint.
    ignore_mismatched_sizes=True,
)
bert_tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
def get_answer_using_bert(question, reference_text):
    """Extract an answer to `question` from `reference_text` with SQuAD-finetuned BERT.

    Returns the predicted answer span as a space-joined token string with
    WordPiece '##' continuation markers collapsed. May return '' when the
    model predicts an end position before the start position.
    """
    # Encode question + context as a sentence pair. Truncate so the sequence
    # never exceeds BERT's 512-token positional-embedding limit — without
    # this, long contexts crash the forward pass.
    input_ids = bert_tokenizer.encode(
        question, reference_text, truncation=True, max_length=512
    )
    input_tokens = bert_tokenizer.convert_ids_to_tokens(input_ids)

    # Segment ids: 0 for the question (through its trailing [SEP]),
    # 1 for the reference text.
    sep_location = input_ids.index(bert_tokenizer.sep_token_id)
    first_seg_len = sep_location + 1
    second_seg_len = len(input_ids) - first_seg_len
    seg_embedding = [0] * first_seg_len + [1] * second_seg_len

    # Inference only: no_grad skips autograd bookkeeping (saves memory/time).
    with torch.no_grad():
        model_scores = bert_model(
            torch.tensor([input_ids]),
            token_type_ids=torch.tensor([seg_embedding]),
        )

    # model_scores[0] / model_scores[1] are the start / end logits.
    ans_start_loc = torch.argmax(model_scores[0])
    ans_end_loc = torch.argmax(model_scores[1])

    result = ' '.join(input_tokens[ans_start_loc:ans_end_loc + 1])
    # Re-join WordPiece pieces ('play' + '##ing' -> 'playing').
    return result.replace(' ##', '')
async def questionAnswering(request: Request):
    """Answer one query against each context in the request body.

    Expects JSON of the form
    {"query": str, "context_list": [{"id": ..., "context": str}, ...]}.
    Returns {"results": [...]} on success; on any failure returns
    {"Error": "<message>"} instead of raising.
    """
    try:
        payload = await request.json()
        question = payload['query']
        contexts = payload['context_list']

        # One answer record per supplied context, in input order.
        results = [
            {
                'answer': get_answer_using_bert(
                    question, item['context'].replace("\n", " ")
                ),
                'id': item['id'],
                'question': question,
            }
            for item in contexts
        ]
        return {"results": results}
    except Exception as exc:
        # Boundary handler: report the failure to the caller as a payload.
        return {"Error": str(exc)}