#!/usr/bin/env python
"""
Streamlit application for Question Answering system.
Optimized for deployment on Hugging Face Spaces.
"""
import streamlit as st
import os
import time
import torch
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
import json

# Page configuration (must be the first Streamlit call in the script)
st.set_page_config(
    page_title="Question Answering System",
    page_icon="❓",
    layout="wide"
)

# Display-name -> Hugging Face Hub model id, all fine-tuned on SQuAD-style QA.
MODELS = {
    "ELECTRA-small": "mrm8488/electra-small-finetuned-squadv1",
    "ALBERT-base-v2": "twmkn9/albert-base-v2-squad2",
    "DistilBERT-base": "distilbert-base-cased-distilled-squad"
}


@st.cache_resource
def load_model(model_name):
    """Load a question-answering pipeline for *model_name*.

    Cached with ``st.cache_resource`` so each model is downloaded and
    instantiated at most once per server process.

    Args:
        model_name: A key of ``MODELS``.

    Returns:
        A transformers question-answering pipeline, or ``None`` if
        loading failed (the error is surfaced in the UI).
    """
    try:
        model_path = MODELS[model_name]
        qa_pipeline = pipeline("question-answering", model=model_path)
        return qa_pipeline
    except Exception as e:
        st.error(f"Error loading model {model_name}: {e}")
        return None


def answer_question(qa_pipeline, question, context):
    """Answer *question* given *context* using the QA pipeline.

    Args:
        qa_pipeline: A transformers question-answering pipeline.
        question: The question string.
        context: The passage to extract the answer from.

    Returns:
        Tuple ``(answer, score, inference_time)``. If either input is
        empty, returns ``(None, 0, 0)`` without invoking the model.
    """
    if not question or not context:
        return None, 0, 0

    # Wall-clock time around the model call only.
    start_time = time.time()
    result = qa_pipeline(question=question, context=context)
    inference_time = time.time() - start_time

    return result["answer"], result["score"], inference_time


def highlight_answer(context, answer):
    """Return *context* with the first occurrence of *answer* highlighted.

    The match is case-insensitive; the matched span is wrapped in an HTML
    ``<mark>`` element so it renders highlighted when passed to
    ``st.markdown(..., unsafe_allow_html=True)``. If *answer* is not found
    (or either argument is empty), the context is returned unchanged.
    """
    if not answer or not context:
        return context

    # Locate the answer case-insensitively, but highlight the original text.
    lower_context = context.lower()
    lower_answer = answer.lower()

    if lower_answer in lower_context:
        start_idx = lower_context.find(lower_answer)
        end_idx = start_idx + len(lower_answer)
        # BUG FIX: the original f-string contained no markup, so nothing was
        # ever highlighted. Wrap the matched span in a <mark> tag.
        highlighted = (
            context[:start_idx]
            + f'<mark style="background-color: #FFFF00">{context[start_idx:end_idx]}</mark>'
            + context[end_idx:]
        )
        return highlighted

    return context


def generate_comparison_chart(results_df):
    """Generate a side-by-side bar chart comparing model results.

    Args:
        results_df: DataFrame with columns ``model_name``, ``score`` and
            ``inference_time`` (one row per model).

    Returns:
        A matplotlib Figure with confidence scores on the left and
        inference times on the right.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Sort so the most confident model appears first in both charts.
    results_df = results_df.sort_values('score', ascending=False)

    models = results_df['model_name']
    scores = results_df['score']

    ax1.barh(models, scores, color='skyblue')
    ax1.set_xlabel('Confidence Score')
    ax1.set_title('Model Confidence Scores')
    ax1.grid(axis='x', linestyle='--', alpha=0.7)

    # Times may arrive as strings from earlier formatting; coerce to float.
    inference_times = results_df['inference_time'].astype(float)
    ax2.barh(models, inference_times, color='salmon')
    ax2.set_xlabel('Inference Time (seconds)')
    ax2.set_title('Model Inference Times')
    ax2.grid(axis='x', linestyle='--', alpha=0.7)

    plt.tight_layout()
    return fig


def main():
    """Render the Streamlit UI: inputs, single-model answering and
    all-model comparison."""
    st.title("Question Answering System")
    st.markdown("""
    This application answers questions based on the provided context using
    transformer-based models fine-tuned on the SQuAD dataset.
    Enter a context paragraph and ask questions about it.
    """)

    # Session state keeps comparison results across Streamlit reruns.
    if 'comparison_results' not in st.session_state:
        st.session_state.comparison_results = None

    col1, col2 = st.columns([3, 1])

    with col1:
        # Context input (default: SQuAD "Normans" passage).
        context = st.text_area(
            "Context",
            "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
            height=200
        )

        question = st.text_input("Question", "In what country is Normandy located?")

        st.markdown("---")
        st.subheader("Results")

        if st.button("Compare All Models"):
            progress_bar = st.progress(0)
            # FIX: create ONE status placeholder outside the loop so each
            # model's status message replaces the previous one instead of
            # stacking a new (stale) placeholder per iteration.
            status_text = st.empty()
            results = []

            for i, model_name in enumerate(MODELS.keys()):
                status_text.text(f"Processing with {model_name}...")

                qa_pipeline = load_model(model_name)
                if qa_pipeline is not None:
                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
                    results.append({
                        "model_name": model_name,
                        "answer": answer,
                        "score": score,
                        "inference_time": inference_time
                    })

                progress_bar.progress((i + 1) / len(MODELS))

            # Clear the transient status message once all models are done.
            status_text.empty()

            if results:
                results_df = pd.DataFrame(results)

                # Format a display copy; keep raw numbers for charting.
                display_df = results_df.copy()
                display_df["inference_time"] = display_df["inference_time"].apply(lambda x: f"{x:.4f} s")
                display_df["score"] = display_df["score"].apply(lambda x: f"{x:.4f}")
                st.table(display_df)

                # Persist raw results so the chart survives reruns.
                st.session_state.comparison_results = results_df

                st.subheader("Model Comparison")
                comparison_chart = generate_comparison_chart(results_df)
                st.pyplot(comparison_chart)

    with col2:
        st.subheader("Available Models")
        selected_model = st.selectbox(
            "Select a model",
            list(MODELS.keys()),
            key="model_selector"
        )

        if st.button("Answer Question"):
            with st.spinner(f"Loading {selected_model}..."):
                qa_pipeline = load_model(selected_model)

            if qa_pipeline is not None:
                with st.spinner("Generating answer..."):
                    answer, score, inference_time = answer_question(qa_pipeline, question, context)

                st.success("Answer generated!")
                st.markdown(f"**Model:** {selected_model}")
                st.markdown(f"**Answer:** {answer}")
                st.markdown(f"**Confidence:** {score:.4f}")
                st.markdown(f"**Inference Time:** {inference_time:.4f} seconds")

                # Show where the answer sits inside the passage.
                st.subheader("Answer in Context")
                highlighted_context = highlight_answer(context, answer)
                st.markdown(highlighted_context, unsafe_allow_html=True)

        with st.expander("Model Information"):
            st.markdown("""
            **ELECTRA-small**
            A smaller, efficient model with good performance and speed.

            **ALBERT-base-v2**
            Parameter-efficient model with strong performance.

            **DistilBERT-base**
            Distilled BERT model that's faster while maintaining accuracy.
            """)


if __name__ == "__main__":
    main()