from fastapi import FastAPI, Request
import json
from transformers import BertForQuestionAnswering, BertTokenizer
import torch

app = FastAPI()

# Load the model and tokenizer once at module import so every request reuses them.
# NOTE(review): ignore_mismatched_sizes=True silently accepts weight-shape
# mismatches against the checkpoint — confirm this is intentional.
bert_model = BertForQuestionAnswering.from_pretrained(
    'bert-large-uncased-whole-word-masking-finetuned-squad',
    ignore_mismatched_sizes=True
)
bert_tokenizer = BertTokenizer.from_pretrained(
    'bert-large-uncased-whole-word-masking-finetuned-squad'
)

# BERT's position embeddings cap the combined question+context sequence length.
MAX_SEQ_LEN = 512


def get_answer_using_bert(question, reference_text):
    """Extract the answer span for ``question`` from ``reference_text``.

    Runs the SQuAD-finetuned BERT model over the (question, context) pair and
    returns the de-wordpieced answer string, or '' when the predicted span is
    empty/inverted.

    Args:
        question: Natural-language question string.
        reference_text: Context passage to search for the answer.

    Returns:
        The predicted answer as a plain string (may be '').
    """
    # Tokenize the pair. Truncation prevents a runtime error when the combined
    # sequence exceeds the model's 512-token limit.
    input_ids = bert_tokenizer.encode(
        question, reference_text, truncation=True, max_length=MAX_SEQ_LEN
    )
    input_tokens = bert_tokenizer.convert_ids_to_tokens(input_ids)

    # Segment ids: 0 for the question (up to and including the first [SEP]),
    # 1 for the context that follows.
    sep_location = input_ids.index(bert_tokenizer.sep_token_id)
    first_seg_len = sep_location + 1
    second_seg_len = len(input_ids) - first_seg_len
    seg_embedding = [0] * first_seg_len + [1] * second_seg_len

    # Inference only — disable autograd so no gradient graph is built.
    with torch.no_grad():
        model_scores = bert_model(
            torch.tensor([input_ids]),
            token_type_ids=torch.tensor([seg_embedding]),
        )

    # Most likely start/end indices of the answer span (plain ints, not tensors).
    ans_start_loc = torch.argmax(model_scores[0]).item()
    ans_end_loc = torch.argmax(model_scores[1]).item()

    # An end index before the start index means no coherent span was found.
    if ans_end_loc < ans_start_loc:
        return ''

    result = ' '.join(input_tokens[ans_start_loc:ans_end_loc + 1])
    # Re-join wordpiece fragments ('play', '##ing' -> 'playing').
    result = result.replace(' ##', '')
    return result


@app.post("/questionAnswering")
async def questionAnswering(request: Request):
    """Answer ``query`` against every entry of ``context_list``.

    Expects a JSON body of the form::

        {"query": str, "context_list": [{"id": ..., "context": str}, ...]}

    Returns ``{"results": [...]}`` on success, each item carrying the answer,
    the context id, and the original question. On any failure the error message
    is returned as ``{"Error": msg}`` (still HTTP 200 — callers must inspect
    the body).
    """
    try:
        json_data = await request.json()
        query = json_data['query']
        context_list = json_data['context_list']

        result = []
        for val in context_list:
            # Newlines confuse the tokenizer's notion of one passage.
            context = val['context'].replace("\n", " ")
            answer_json_final = {
                'answer': get_answer_using_bert(query, context),
                'id': val['id'],
                'question': query
            }
            result.append(answer_json_final)

        return {"results": result}
    except Exception as e:
        # Broad catch is the top-level request boundary; the error envelope
        # shape is part of the API contract, so it is preserved as-is.
        return {"Error": str(e)}