vishnuraggav's picture
First
41e2723
raw
history blame contribute delete
940 Bytes
''' Import Modules '''
from transformers import BertTokenizer, BertForQuestionAnswering
import gradio as gr
import torch
# --- Setup ---
# Load the tokenizer and the question-answering model once, at import time,
# from the local "./bert_tokenizer" and "./bert_model" directories (relative
# paths, so they are resolved against the current working directory).
tokenizer, model = (
    BertTokenizer.from_pretrained("./bert_tokenizer"),
    BertForQuestionAnswering.from_pretrained("./bert_model"),
)
def _best_span(start_logits, end_logits, candidate_mask, max_answer_len=30):
    """Return the (start, end) token indices of the highest-scoring answer span.

    A span is valid when both ends lie on candidate tokens (``candidate_mask``
    is True there), ``start <= end``, and it is at most ``max_answer_len``
    tokens long. Its score is ``start_logits[start] + end_logits[end]``.

    Args:
        start_logits: 1-D tensor of per-token start scores.
        end_logits: 1-D tensor of per-token end scores, same length.
        candidate_mask: 1-D bool tensor marking tokens allowed in an answer.
        max_answer_len: Maximum span length in tokens.

    Returns:
        ``(start, end)`` as Python ints, or ``None`` if no valid span exists.
    """
    seq_len = start_logits.shape[0]
    # scores[i, j] = start_logits[i] + end_logits[j] for every (start, end) pair.
    scores = start_logits[:, None] + end_logits[None, :]
    ones = torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device)
    # Keep only pairs with 0 <= end - start < max_answer_len.
    valid = torch.triu(ones) & ~torch.triu(ones, diagonal=max_answer_len)
    valid &= candidate_mask[:, None] & candidate_mask[None, :]
    if not bool(valid.any()):
        return None
    scores = scores.masked_fill(~valid, float("-inf"))
    start, end = divmod(int(torch.argmax(scores)), seq_len)
    return start, end


def main(context, question):
    """Answer *question* by extracting a span of *context* with the BERT QA model.

    Args:
        context: Passage to search for the answer.
        question: Question to answer from the passage.

    Returns:
        The decoded answer text, or ``""`` when either input is blank or no
        valid answer span lies inside the context.
    """
    if not (context and context.strip() and question and question.strip()):
        return ""
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        add_special_tokens=True,
        # BERT has a fixed number of position embeddings; trim only the
        # context (never the question) so long passages don't crash the model.
        truncation="only_second",
        max_length=model.config.max_position_embeddings,
    )
    with torch.inference_mode():
        outputs = model(**inputs)
    input_ids = inputs["input_ids"][0]
    # For a (question, context) pair BERT gives the context segment
    # token_type_id 1; that segment also holds the trailing [SEP], excluded here
    # so the answer can never be [CLS], question text, or a separator.
    candidate_mask = (inputs["token_type_ids"][0] == 1) & (
        input_ids != tokenizer.sep_token_id
    )
    # Pick start and end jointly (argmax over logits directly; softmax is
    # monotonic and unnecessary) so that end >= start is guaranteed.
    span = _best_span(outputs.start_logits[0], outputs.end_logits[0], candidate_mask)
    if span is None:
        return ""
    start, end = span
    return tokenizer.decode(input_ids[start : end + 1], skip_special_tokens=True)
# --- App interface ---
# Two text inputs are passed positionally to `main` (first -> context,
# second -> question); the returned answer is shown in a single output box.
input_boxes = [gr.Textbox() for _ in range(2)]
answer_box = gr.Textbox()
app = gr.Interface(fn=main, inputs=input_boxes, outputs=answer_box)
app.launch()