#!/usr/bin/env python
"""
Streamlit application for Question Answering system.
Optimized for deployment on Hugging Face Spaces.
"""
import streamlit as st
import os
import time
import torch
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
import json
# Page configuration — st.set_page_config must be the first Streamlit
# call executed in the script, so it sits above all other UI code.
st.set_page_config(
    page_title="Question Answering System",
    page_icon="❓",
    layout="wide"
)

# Model registry: UI display name -> Hugging Face Hub model id.
# Judging by the ids, all are fine-tuned on SQuAD-style extractive QA
# (v1 or v2) — confirm against the Hub cards if exact dataset matters.
MODELS = {
    "ELECTRA-small": "mrm8488/electra-small-finetuned-squadv1",
    "ALBERT-base-v2": "twmkn9/albert-base-v2-squad2",
    "DistilBERT-base": "distilbert-base-cased-distilled-squad"
}
# Cache for loaded models
# Cache for loaded models
@st.cache_resource
def load_model(model_name):
    """Build (and cache, via st.cache_resource) a question-answering
    pipeline for the given display name from MODELS.

    Returns the pipeline, or None after surfacing the error in the UI
    if the model cannot be loaded.
    """
    try:
        return pipeline("question-answering", model=MODELS[model_name])
    except Exception as e:
        # Report in the app rather than crashing the whole script run.
        st.error(f"Error loading model {model_name}: {e}")
        return None
def answer_question(qa_pipeline, question, context):
    """Run an extractive QA pipeline on (question, context) and time it.

    Parameters
    ----------
    qa_pipeline : callable
        Accepts ``question=`` / ``context=`` keywords and returns a dict
        with at least ``"answer"`` and ``"score"`` keys (a transformers
        question-answering pipeline, or any compatible stub).
    question, context : str

    Returns
    -------
    tuple
        ``(answer, score, inference_time_seconds)``, or ``(None, 0, 0)``
        when either the question or the context is empty.
    """
    if not question or not context:
        return None, 0, 0

    # perf_counter() is monotonic and high-resolution, so the measured
    # interval is immune to wall-clock adjustments (time.time() is not
    # guaranteed to be — it can jump under NTP corrections).
    start_time = time.perf_counter()
    result = qa_pipeline(question=question, context=context)
    inference_time = time.perf_counter() - start_time

    return result["answer"], result["score"], inference_time
def highlight_answer(context, answer):
    """Wrap the first case-insensitive occurrence of *answer* within
    *context* in a highlighted HTML <span>.

    The original casing of the matched text is preserved. If either
    argument is empty/None, or the answer is not found, the context is
    returned unchanged.
    """
    if not context or not answer:
        return context

    # Case-insensitive search; slice the original string so the
    # displayed text keeps its source capitalisation.
    idx = context.lower().find(answer.lower())
    if idx == -1:
        return context

    end = idx + len(answer)
    prefix, match, suffix = context[:idx], context[idx:end], context[end:]
    return (
        prefix
        + f'<span style="background-color: #ffdd99; font-weight: bold;">{match}</span>'
        + suffix
    )
def generate_comparison_chart(results_df):
    """Build a two-panel horizontal bar chart comparing models.

    Left panel: confidence scores; right panel: inference times.
    Rows are ordered highest-score-first. Expects the columns
    'model_name', 'score' and 'inference_time'. Returns the Figure.
    """
    fig, (score_ax, time_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Highest-confidence model at the top of both panels.
    ordered = results_df.sort_values('score', ascending=False)
    names = ordered['model_name']

    score_ax.barh(names, ordered['score'], color='skyblue')
    score_ax.set_xlabel('Confidence Score')
    score_ax.set_title('Model Confidence Scores')
    score_ax.grid(axis='x', linestyle='--', alpha=0.7)

    # Times may arrive as strings/objects; coerce to float for plotting.
    time_ax.barh(names, ordered['inference_time'].astype(float), color='salmon')
    time_ax.set_xlabel('Inference Time (seconds)')
    time_ax.set_title('Model Inference Times')
    time_ax.grid(axis='x', linestyle='--', alpha=0.7)

    plt.tight_layout()
    return fig
def main():
    """Render the Streamlit UI: context/question inputs, single-model
    answering with highlighted context, and an all-model comparison
    table plus chart."""
    # Title and description
    st.title("Question Answering System")
    st.markdown("""
    This application answers questions based on the provided context using transformer-based models
    fine-tuned on the SQuAD dataset. Enter a context paragraph and ask questions about it.
    """)

    # Initialize session state for storing results across script reruns
    if 'comparison_results' not in st.session_state:
        st.session_state.comparison_results = None

    # Layout: wide column for inputs/results, narrow column for model picker
    col1, col2 = st.columns([3, 1])

    with col1:
        # Context input (default: the SQuAD "Normans" paragraph)
        context = st.text_area(
            "Context",
            "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
            height=200
        )

        # Question input
        question = st.text_input("Question", "In what country is Normandy located?")

        # Add a separator
        st.markdown("---")

        # Results section
        st.subheader("Results")

        if st.button("Compare All Models"):
            progress_bar = st.progress(0)
            results = []

            # Process each model (load_model caches, so repeat runs are fast)
            for i, model_name in enumerate(MODELS.keys()):
                status_text = st.empty()
                status_text.text(f"Processing with {model_name}...")

                # Load model
                qa_pipeline = load_model(model_name)
                if qa_pipeline is not None:
                    # Get answer
                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
                    # Store results (raw numeric values; formatted later)
                    results.append({
                        "model_name": model_name,
                        "answer": answer,
                        "score": score,
                        "inference_time": inference_time
                    })

                # Update progress
                progress_bar.progress((i + 1) / len(MODELS))

            # Display results in a table (formatted copy; raw values kept
            # in results_df for charting)
            if results:
                results_df = pd.DataFrame(results)
                display_df = results_df.copy()
                display_df["inference_time"] = display_df["inference_time"].apply(lambda x: f"{x:.4f} s")
                display_df["score"] = display_df["score"].apply(lambda x: f"{x:.4f}")
                st.table(display_df)

                # Save results to session state for comparison chart
                st.session_state.comparison_results = results_df

                # Show comparison chart
                st.subheader("Model Comparison")
                comparison_chart = generate_comparison_chart(results_df)
                st.pyplot(comparison_chart)

    with col2:
        # Model selection
        st.subheader("Available Models")
        selected_model = st.selectbox(
            "Select a model",
            list(MODELS.keys()),
            key="model_selector"
        )

        # Load selected model and answer
        if st.button("Answer Question"):
            with st.spinner(f"Loading {selected_model}..."):
                qa_pipeline = load_model(selected_model)
            if qa_pipeline is not None:
                with st.spinner("Generating answer..."):
                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
                st.success("Answer generated!")
                st.markdown(f"**Model:** {selected_model}")
                st.markdown(f"**Answer:** {answer}")
                st.markdown(f"**Confidence:** {score:.4f}")
                st.markdown(f"**Inference Time:** {inference_time:.4f} seconds")

                # Highlight answer in context (HTML from highlight_answer)
                st.subheader("Answer in Context")
                highlighted_context = highlight_answer(context, answer)
                st.markdown(highlighted_context, unsafe_allow_html=True)

        # Advanced options
        with st.expander("Model Information"):
            st.markdown("""
            **ELECTRA-small**
            A smaller, efficient model with good performance and speed.
            **ALBERT-base-v2**
            Parameter-efficient model with strong performance.
            **DistilBERT-base**
            Distilled BERT model that's faster while maintaining accuracy.
            """)
# Standard script entry point — run the app only when executed directly.
if __name__ == "__main__":
    main()