#!/usr/bin/env python
"""
Streamlit application for Question Answering system.
Optimized for deployment on Hugging Face Spaces.
"""
import streamlit as st
import time
import pandas as pd
import matplotlib.pyplot as plt
from transformers import pipeline
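
# To run locally (assuming the repository layout places this file at
# src/streamlit_app.py):
#   streamlit run src/streamlit_app.py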

# Page configuration
st.set_page_config(
    page_title="Question Answering System",
    page_icon="❓",
    layout="wide"
)

# Constants
MODELS = {
    "ELECTRA-small": "mrm8488/electra-small-finetuned-squadv1",
    "ALBERT-base-v2": "twmkn9/albert-base-v2-squad2",
    "DistilBERT-base": "distilbert-base-cased-distilled-squad"
}
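# The values are Hugging Face Hub model IDs; as their names suggest, each
# checkpoint was fine-tuned on SQuAD (v1) or SQuAD v2.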

# Cache for loaded models
@st.cache_resource
def load_model(model_name):
    """Load model and tokenizer with caching"""
    try:
        model_path = MODELS[model_name]
        qa_pipeline = pipeline("question-answering", model=model_path)
        return qa_pipeline
    except Exception as e:
        st.error(f"Error loading model {model_name}: {e}")
        return None
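
# Note: st.cache_resource keeps one pipeline instance per model name for the
# lifetime of the server process, so each model is downloaded and initialized
# only once and then shared across reruns and user sessions.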

def answer_question(qa_pipeline, question, context):
    """
    Answer a question given a context using the QA pipeline.
    Returns an (answer, score, inference_time) tuple.
    """
    if not question or not context:
        return None, 0, 0

    # Measure inference time
    start_time = time.time()

    # Run model
    result = qa_pipeline(question=question, context=context)

    # Calculate inference time
    inference_time = time.time() - start_time

    return result["answer"], result["score"], inference_time
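
# For reference, the QA pipeline returns a dict of the form
# {"score": ..., "start": ..., "end": ..., "answer": ...}, e.g.
# {"score": 0.98, "start": 159, "end": 165, "answer": "France"}
# (illustrative values); only "answer" and "score" are used above.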

def highlight_answer(context, answer):
    """Highlight the answer in the context with HTML"""
    if not answer or not context:
        return context

    # Find the answer in the context (case insensitive)
    lower_context = context.lower()
    lower_answer = answer.lower()

    if lower_answer in lower_context:
        start_idx = lower_context.find(lower_answer)
        end_idx = start_idx + len(lower_answer)
        # Create HTML with highlighted answer
        highlighted = (
            context[:start_idx] +
            f'<span style="background-color: #ffdd99; font-weight: bold;">{context[start_idx:end_idx]}</span>' +
            context[end_idx:]
        )
        return highlighted

    return context
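
# Example (illustrative):
#   highlight_answer("Normandy is a region in France.", "France")
#   -> 'Normandy is a region in <span style="...">France</span>.'
# Only the first (case-insensitive) occurrence of the answer is highlighted.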

def generate_comparison_chart(results_df):
    """Generate a comparison chart for model results"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Sort models by score
    results_df = results_df.sort_values('score', ascending=False)

    # Plot scores
    models = results_df['model_name']
    scores = results_df['score']
    ax1.barh(models, scores, color='skyblue')
    ax1.set_xlabel('Confidence Score')
    ax1.set_title('Model Confidence Scores')
    ax1.grid(axis='x', linestyle='--', alpha=0.7)

    # Plot inference times
    inference_times = results_df['inference_time'].astype(float)
    ax2.barh(models, inference_times, color='salmon')
    ax2.set_xlabel('Inference Time (seconds)')
    ax2.set_title('Model Inference Times')
    ax2.grid(axis='x', linestyle='--', alpha=0.7)

    plt.tight_layout()
    return fig
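
# generate_comparison_chart expects a DataFrame with "model_name", "score",
# and "inference_time" columns, e.g. (illustrative values):
#   pd.DataFrame([{"model_name": "DistilBERT-base", "score": 0.98,
#                  "inference_time": 0.12}])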

def main():
    # Title and description
    st.title("Question Answering System")
    st.markdown(
        "This application answers questions based on the provided context using "
        "transformer-based models fine-tuned on the SQuAD dataset. Enter a context "
        "paragraph and ask questions about it."
    )

    # Initialize session state for storing results
    if 'comparison_results' not in st.session_state:
        st.session_state.comparison_results = None

    # Layout
    col1, col2 = st.columns([3, 1])

    with col1:
        # Context input
        context = st.text_area(
            "Context",
            "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
            height=200
        )

        # Question input
        question = st.text_input("Question", "In what country is Normandy located?")

        # Add a separator
        st.markdown("---")

        # Results section
        st.subheader("Results")

        if st.button("Compare All Models"):
            progress_bar = st.progress(0)
            # Single status placeholder, reused across iterations so stale
            # messages don't accumulate
            status_text = st.empty()
            results = []

            # Process each model
            for i, model_name in enumerate(MODELS.keys()):
                status_text.text(f"Processing with {model_name}...")

                # Load model
                qa_pipeline = load_model(model_name)

                if qa_pipeline is not None:
                    # Get answer
                    answer, score, inference_time = answer_question(
                        qa_pipeline, question, context
                    )

                    # Store results
                    results.append({
                        "model_name": model_name,
                        "answer": answer,
                        "score": score,
                        "inference_time": inference_time
                    })

                # Update progress
                progress_bar.progress((i + 1) / len(MODELS))

            status_text.empty()

            # Display results in a table
            if results:
                results_df = pd.DataFrame(results)
                display_df = results_df.copy()
                display_df["inference_time"] = display_df["inference_time"].apply(lambda x: f"{x:.4f} s")
                display_df["score"] = display_df["score"].apply(lambda x: f"{x:.4f}")
                st.table(display_df)

                # Save raw results to session state
                st.session_state.comparison_results = results_df

                # Show comparison chart
                st.subheader("Model Comparison")
                comparison_chart = generate_comparison_chart(results_df)
                st.pyplot(comparison_chart)

    with col2:
        # Model selection
        st.subheader("Available Models")
        selected_model = st.selectbox(
            "Select a model",
            list(MODELS.keys()),
            key="model_selector"
        )

        # Load selected model and answer
        if st.button("Answer Question"):
            with st.spinner(f"Loading {selected_model}..."):
                qa_pipeline = load_model(selected_model)

            if qa_pipeline is not None:
                with st.spinner("Generating answer..."):
                    answer, score, inference_time = answer_question(
                        qa_pipeline, question, context
                    )

                st.success("Answer generated!")
                st.markdown(f"**Model:** {selected_model}")
                st.markdown(f"**Answer:** {answer}")
                st.markdown(f"**Confidence:** {score:.4f}")
                st.markdown(f"**Inference Time:** {inference_time:.4f} seconds")

                # Highlight answer in context
                st.subheader("Answer in Context")
                highlighted_context = highlight_answer(context, answer)
                st.markdown(highlighted_context, unsafe_allow_html=True)

        # Advanced options
        with st.expander("Model Information"):
            st.markdown(
                "**ELECTRA-small**  \n"
                "A smaller, efficient model with good performance and speed.\n\n"
                "**ALBERT-base-v2**  \n"
                "Parameter-efficient model with strong performance.\n\n"
                "**DistilBERT-base**  \n"
                "Distilled BERT model that's faster while maintaining accuracy."
            )


if __name__ == "__main__":
    main()