# app.py — BERT question-answering API (Hugging Face Space by tuyen4656789, commit 7f82e14)
from fastapi import FastAPI, Request
# NOTE(review): `json` appears unused in this chunk — requests are parsed via
# Request.json(); verify before removing.
import json
from transformers import BertForQuestionAnswering, BertTokenizer
import torch
# FastAPI application instance; the route handler below registers against it.
app = FastAPI()
# Load the model and tokenizer once, globally
# Loading at import time means the (large) SQuAD-finetuned BERT weights are
# fetched/read once per process instead of once per request.
bert_model = BertForQuestionAnswering.from_pretrained(
'bert-large-uncased-whole-word-masking-finetuned-squad',
ignore_mismatched_sizes=True
)
bert_tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
def get_answer_using_bert(question, reference_text):
    """Extract an answer span for `question` from `reference_text`.

    Uses the globally loaded SQuAD-finetuned BERT model and tokenizer.

    Parameters
    ----------
    question : str
        The natural-language question.
    reference_text : str
        The context passage to search for the answer.

    Returns
    -------
    str
        The predicted answer span with WordPiece sub-tokens re-joined, or an
        empty string when the model predicts an invalid (inverted) span.
    """
    # Encode as a sentence pair: [CLS] question [SEP] context [SEP].
    # Truncate to BERT's 512-token limit so long contexts do not crash the model.
    input_ids = bert_tokenizer.encode(
        question, reference_text, max_length=512, truncation=True
    )
    input_tokens = bert_tokenizer.convert_ids_to_tokens(input_ids)

    # Segment ids: 0 for the question segment (up to and including the first
    # [SEP]), 1 for the context segment.
    sep_location = input_ids.index(bert_tokenizer.sep_token_id)
    first_seg_len = sep_location + 1
    second_seg_len = len(input_ids) - first_seg_len
    seg_embedding = [0] * first_seg_len + [1] * second_seg_len

    # Inference only: no_grad avoids building an autograd graph per request,
    # saving memory and time.
    with torch.no_grad():
        output = bert_model(
            torch.tensor([input_ids]),
            token_type_ids=torch.tensor([seg_embedding]),
        )

    # Use the named logits instead of positional indexing (output[0]/output[1]),
    # which is fragile across transformers versions.
    ans_start_loc = int(torch.argmax(output.start_logits))
    ans_end_loc = int(torch.argmax(output.end_logits))

    # Guard against an inverted span — the model found no valid answer.
    if ans_end_loc < ans_start_loc:
        return ""

    # Join tokens and merge WordPiece sub-tokens ("##" prefixes) back together.
    result = ' '.join(input_tokens[ans_start_loc:ans_end_loc + 1])
    return result.replace(' ##', '')
@app.post("/questionAnswering")
async def questionAnswering(request: Request):
    """Answer a single query against each context in a list.

    Expects a JSON body of the form:
        {"query": "<question>",
         "context_list": [{"id": ..., "context": "..."}, ...]}

    Returns {"results": [...]} on success; on any failure returns
    {"Error": "<message>"} instead of raising.
    """
    try:
        payload = await request.json()
        question = payload['query']
        contexts = payload['context_list']

        # Run the model once per context; newlines are flattened to spaces
        # before tokenization.
        answers = [
            {
                'answer': get_answer_using_bert(
                    question, entry['context'].replace("\n", " ")
                ),
                'id': entry['id'],
                'question': question,
            }
            for entry in contexts
        ]
        return {"results": answers}
    except Exception as err:
        # Top-level boundary: surface the error as JSON rather than a 500.
        return {"Error": str(err)}