from fastapi import FastAPI, Request
import json
from transformers import BertForQuestionAnswering, BertTokenizer
import torch
# FastAPI application instance; routes are registered on it below.
app = FastAPI()
# Load the model and tokenizer once, globally
# (loading inside the request handler would re-download/re-initialize the
# large SQuAD-finetuned checkpoint on every call).
bert_model = BertForQuestionAnswering.from_pretrained(
'bert-large-uncased-whole-word-masking-finetuned-squad',
ignore_mismatched_sizes=True
)
# Tokenizer must match the model checkpoint so ids/special tokens line up.
bert_tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
def get_answer_using_bert(question, reference_text):
    """Extract an answer span for ``question`` from ``reference_text``.

    Encodes the pair as ``[CLS] question [SEP] reference_text [SEP]``, builds
    the segment (token_type) ids, runs the SQuAD-finetuned BERT model, and
    returns the wordpiece tokens between the argmax start and end logits,
    joined with '##' continuation markers stitched back together.

    Returns:
        str: the predicted answer text, or '' when the predicted end index
        precedes the start index (no coherent span).
    """
    # Tokenize the (question, context) pair into a single id sequence.
    input_ids = bert_tokenizer.encode(question, reference_text)
    input_tokens = bert_tokenizer.convert_ids_to_tokens(input_ids)
    # The first [SEP] separates the question segment (type 0) from the
    # context segment (type 1).
    sep_location = input_ids.index(bert_tokenizer.sep_token_id)
    first_seg_len = sep_location + 1
    second_seg_len = len(input_ids) - first_seg_len
    seg_embedding = [0] * first_seg_len + [1] * second_seg_len
    # Inference only: no_grad avoids building the autograd graph
    # (less memory, faster forward pass).
    with torch.no_grad():
        model_scores = bert_model(
            torch.tensor([input_ids]),
            token_type_ids=torch.tensor([seg_embedding]),
        )
    # model_scores[0]/[1] index the start/end logits; indexing works for both
    # tuple outputs and transformers ModelOutput objects.
    ans_start_loc = torch.argmax(model_scores[0])
    ans_end_loc = torch.argmax(model_scores[1])
    # Degenerate prediction: end before start means no valid span.
    if ans_end_loc < ans_start_loc:
        return ''
    result = ' '.join(input_tokens[ans_start_loc:ans_end_loc + 1])
    # Re-join wordpieces: 'play ##ing' -> 'playing'
    result = result.replace(' ##', '')
    return result
@app.post("/questionAnswering")
async def questionAnswering(request: Request):
    """Answer one question against each context in the request body.

    Expected JSON payload:
        {"query": str, "context_list": [{"id": ..., "context": str}, ...]}

    Returns ``{"results": [...]}`` with one record per context on success,
    or ``{"Error": "<message>"}`` when the payload is malformed or the
    model call fails.
    """
    try:
        payload = await request.json()
        question = payload['query']
        contexts = payload['context_list']
        # One answer record per context entry; newlines flattened to spaces
        # before handing the text to the model.
        answers = [
            {
                'answer': get_answer_using_bert(
                    question, entry['context'].replace("\n", " ")
                ),
                'id': entry['id'],
                'question': question,
            }
            for entry in contexts
        ]
        return {"results": answers}
    except Exception as e:
        # Best-effort API boundary: report the failure back to the caller
        # instead of surfacing a 500.
        return {"Error": str(e)}
|