|
|
import streamlit as st |
|
|
import torch |
|
|
from transformers import BertForQuestionAnswering, BertTokenizer |
|
|
|
|
|
|
|
|
# Page-level settings (browser tab title and centered layout). This is the
# first Streamlit call made by the script.
st.set_page_config(
    layout="centered",
    page_title="BERT Question Answering System",
)
|
|
|
|
|
|
|
|
# Checkpoint whose question-answering head has been fine-tuned on SQuAD.
# The plain 'bert-base-uncased' checkpoint contains no weights for the QA
# head, so BertForQuestionAnswering would initialise that head randomly and
# every predicted answer span would be noise. This checkpoint is large
# (BERT-large); substitute a smaller SQuAD-fine-tuned BERT if memory is tight.
DEFAULT_QA_MODEL = "bert-large-uncased-whole-word-masking-finetuned-squad"


@st.cache_resource
def load_model(model_name=DEFAULT_QA_MODEL):
    """Load a BERT question-answering model and its matching tokenizer.

    Wrapped in ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of on every widget interaction.

    Args:
        model_name: Hugging Face checkpoint name or local path. It must be a
            checkpoint fine-tuned for extractive QA. Model and tokenizer are
            both loaded from it so that their vocabularies match.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = BertForQuestionAnswering.from_pretrained(model_name)
    model.eval()  # inference only: make sure dropout is disabled
    tokenizer = BertTokenizer.from_pretrained(model_name)
    return model, tokenizer


model, tokenizer = load_model()
|
|
|
|
|
|
|
|
def _best_span(start_logits, end_logits, candidate_mask, max_answer_len):
    """Return inclusive ``(start, end)`` token indices of the best span, or None.

    A span (i, j) is valid when both i and j are candidate tokens,
    ``j >= i`` and ``j - i < max_answer_len``. Its score is
    ``start_logits[i] + end_logits[j]``. Returns None when no span is valid.
    """
    seq_len = start_logits.size(0)
    positions = torch.arange(seq_len, device=start_logits.device)
    span_len = positions[None, :] - positions[:, None]  # entry [i, j] == j - i
    valid = (span_len >= 0) & (span_len < max_answer_len)
    valid &= candidate_mask[:, None] & candidate_mask[None, :]
    if not bool(valid.any()):
        return None
    scores = start_logits[:, None] + end_logits[None, :]
    flat_index = int(scores.masked_fill(~valid, float("-inf")).argmax().item())
    return divmod(flat_index, seq_len)


def get_answer(question, context, max_answer_len=30):
    """Extract the answer to ``question`` from ``context`` with the BERT QA model.

    Uses the module-level ``model`` and ``tokenizer``. The question and
    context are encoded as one sentence pair. If the pair exceeds 512
    tokens, only the context is truncated, so the question is never cut.
    The answer is the highest-scoring span that lies entirely inside the
    context, ends at or after its start, and is at most ``max_answer_len``
    tokens long.

    Args:
        question: The question text.
        context: The passage to search for the answer.
        max_answer_len: Maximum answer length in tokens.

    Returns:
        The decoded answer text, or ``""`` when no valid span exists
        (e.g. the context produced no tokens).
    """
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        max_length=512,
        truncation="only_second",
    )

    with torch.no_grad():
        outputs = model(**inputs)

    input_ids = inputs["input_ids"][0]
    # BERT marks the second segment (the context) with token_type_id 1. That
    # segment also includes the closing [SEP], which is excluded explicitly so
    # the answer can only come from real context tokens.
    context_mask = (inputs["token_type_ids"][0] == 1) & (
        input_ids != tokenizer.sep_token_id
    )

    span = _best_span(
        outputs.start_logits[0],
        outputs.end_logits[0],
        context_mask,
        max_answer_len,
    )
    if span is None:
        return ""

    start, end = span
    return tokenizer.decode(
        input_ids[start:end + 1].tolist(), skip_special_tokens=True
    ).strip()
|
|
|
|
|
|
|
|
# NOTE(review): the emoji in the labels below appear to be mis-encoded
# (UTF-8 bytes shown through a Greek code page, e.g. "π€", "β"). Restore
# the intended characters once the original text is known.
st.title("π€ BERT Question Answering System")
st.write("This app uses BERT to answer questions based on a given context.")

context = st.text_area("π Enter the context/passage:", height=200)
question = st.text_input("β Ask a question about the context:")

if st.button("Get Answer"):
    # Check stripped text: whitespace-only input is effectively empty and
    # would otherwise be sent to the model.
    if not context.strip() or not question.strip():
        st.warning("Please provide both a context and a question.")
    else:
        try:
            # Model inference can be slow on CPU; show progress meanwhile.
            with st.spinner("Finding the answer..."):
                answer = get_answer(question, context)
        except Exception as e:  # UI boundary: report the failure instead of crashing the page
            st.error(f"An error occurred: {e}")
        else:
            if answer:
                st.success(f"π Answer: {answer}")
            else:
                st.warning("No answer found in the given context.")
|
|
|
|
|
|
|
|
# Raw CSS injected into the page: larger font in the text input and text
# area, and a bold green button that darkens on hover.
_CUSTOM_CSS = """
<style>
    .stTextInput input, .stTextArea textarea {
        font-size: 16px !important;
    }
    .stButton button {
        background-color: #4CAF50;
        color: white;
        font-weight: bold;
        padding: 0.5rem 1rem;
        border-radius: 5px;
    }
    .stButton button:hover {
        background-color: #45a049;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Footer: a horizontal rule, then the attribution line, rendered in order.
_FOOTER_LINES = (
    "---",
    "Built with β€οΈ using Streamlit and HuggingFace Transformers",
)
for _footer_line in _FOOTER_LINES:
    st.markdown(_footer_line)