| ''' Import Modules ''' |
| from transformers import BertTokenizer, BertForQuestionAnswering |
| import gradio as gr |
| import torch |
|
|
''' Setup '''
# Load the fine-tuned extractive-QA tokenizer and model from local
# directories saved earlier with `save_pretrained`.
TOKENIZER_DIR = "./bert_tokenizer"
MODEL_DIR = "./bert_model"

tokenizer = BertTokenizer.from_pretrained(TOKENIZER_DIR)
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR)
|
|
def main(context, question):
    """Answer *question* from *context* with extractive BERT QA.

    Args:
        context: Passage text the answer is extracted from.
        question: Natural-language question about the passage.

    Returns:
        The decoded answer span as a string, or an empty string when the
        model predicts an invalid span (end position before start).
    """
    # Encoded as: [CLS] question [SEP] context [SEP].
    # truncation=True keeps over-long inputs within the model's max
    # sequence length instead of crashing the forward pass.
    inputs = tokenizer(question, context, return_tensors='pt',
                       add_special_tokens=True, truncation=True)

    with torch.inference_mode():
        outputs = model(**inputs)

    # argmax on raw logits: softmax is monotonic, so applying it first
    # (as a previous version did) cannot change the argmax result.
    start_index = torch.argmax(outputs.start_logits, dim=1).item()
    end_index = torch.argmax(outputs.end_logits, dim=1).item()

    # Guard against an invalid predicted span.
    if end_index < start_index:
        return ""

    answer_tokens = inputs['input_ids'][0][start_index:end_index + 1]
    # skip_special_tokens drops [CLS]/[SEP] that would otherwise leak
    # into the answer when the span touches a special-token position.
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return answer
|
|
''' App Interface '''
# Wire the QA function into a minimal Gradio UI:
# two text boxes in (context, question), one text box out (answer).
context_box = gr.Textbox()
question_box = gr.Textbox()
answer_box = gr.Textbox()

app = gr.Interface(
    fn=main,
    inputs=[context_box, question_box],
    outputs=answer_box,
)

app.launch()
|
|