File size: 1,949 Bytes
d74427e
ef53edc
c9b9353
46cc82d
ef53edc
d74427e
46cc82d
c9b9353
7f82e14
 
 
 
c9b9353
46cc82d
c9b9353
 
46cc82d
 
 
c9b9353
46cc82d
 
 
 
c9b9353
46cc82d
 
 
 
c9b9353
46cc82d
 
 
d74427e
 
ef53edc
d74427e
ef53edc
 
 
c9b9353
 
ef53edc
c9b9353
 
 
 
 
 
ef53edc
 
d74427e
ef53edc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from fastapi import FastAPI, Request
import json
from transformers import BertForQuestionAnswering, BertTokenizer
import torch

app = FastAPI()

# Load the model and tokenizer once, globally, at import time so every
# request reuses the same weights (loading bert-large is slow and memory-heavy).
# NOTE(review): ignore_mismatched_sizes=True silently drops mismatched head
# weights if the checkpoint ever disagrees with the config — confirm this is
# intentional rather than masking a checkpoint problem.
bert_model = BertForQuestionAnswering.from_pretrained(
    'bert-large-uncased-whole-word-masking-finetuned-squad',
    ignore_mismatched_sizes=True
)
bert_tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

def get_answer_using_bert(question, reference_text):
    """Extract an answer span for `question` from `reference_text`.

    Uses the module-level `bert_model` / `bert_tokenizer` (SQuAD-finetuned
    BERT). The question and context are encoded as a single sequence
    ([CLS] question [SEP] context [SEP]); the model predicts start/end
    logits over tokens and the argmax span is decoded back to text.

    Args:
        question: natural-language question string.
        reference_text: context passage to search for the answer.

    Returns:
        The predicted answer as a string ("" if the predicted end precedes
        the start, since the token slice is then empty).
    """
    # Encode question + context into one token-id sequence.
    input_ids = bert_tokenizer.encode(question, reference_text)
    input_tokens = bert_tokenizer.convert_ids_to_tokens(input_ids)

    # The first [SEP] marks the question/context boundary; tokens up to and
    # including it belong to segment 0, the rest to segment 1.
    sep_location = input_ids.index(bert_tokenizer.sep_token_id)
    first_seg_len = sep_location + 1
    second_seg_len = len(input_ids) - first_seg_len
    seg_embedding = [0] * first_seg_len + [1] * second_seg_len

    # Inference only: disable autograd so no computation graph is built
    # (the original tracked gradients on every request — wasted memory/time).
    with torch.no_grad():
        model_scores = bert_model(
            torch.tensor([input_ids]),
            token_type_ids=torch.tensor([seg_embedding]),
        )

    # Index 0/1 of the model output are the start/end logits; take the
    # highest-scoring position for each.
    ans_start_loc = torch.argmax(model_scores[0])
    ans_end_loc = torch.argmax(model_scores[1])
    result = ' '.join(input_tokens[ans_start_loc:ans_end_loc + 1])

    # Merge WordPiece continuation tokens: "play ##ing" -> "playing".
    result = result.replace(' ##', '')
    return result

@app.post("/questionAnswering")
async def questionAnswering(request: Request):
    try:
        json_data = await request.json()
        query = json_data['query']
        context_list = json_data['context_list']
        result = []

        # Loop over contexts
        for val in context_list:
            context = val['context'].replace("\n", " ")
            answer_json_final = {
                'answer': get_answer_using_bert(query, context),
                'id': val['id'],
                'question': query
            }
            result.append(answer_json_final)

        return {"results": result}

    except Exception as e:
        return {"Error": str(e)}