File size: 8,345 Bytes
d3bbe45
 
 
 
 
 
a0838e4
d3bbe45
 
 
 
 
 
 
a0838e4
d3bbe45
 
 
 
 
 
a0838e4
d3bbe45
 
 
 
 
 
a0838e4
d3bbe45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0838e4
d3bbe45
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
#!/usr/bin/env python
"""
Streamlit application for Question Answering system.
Optimized for deployment on Hugging Face Spaces.
"""

import streamlit as st
import os
import time
import torch
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
import json

# Page configuration — must run before any other Streamlit call in the script.
st.set_page_config(
    page_title="Question Answering System",
    page_icon="❓",
    layout="wide"
)

# Constants
# Maps a human-readable display name (shown in the UI) to a model identifier
# passed to transformers.pipeline() — presumably Hugging Face Hub repo ids,
# each fine-tuned on SQuAD-style extractive QA.
MODELS = {
    "ELECTRA-small": "mrm8488/electra-small-finetuned-squadv1",
    "ALBERT-base-v2": "twmkn9/albert-base-v2-squad2",
    "DistilBERT-base": "distilbert-base-cased-distilled-squad"
}

# Cache for loaded models
@st.cache_resource
def load_model(model_name):
    """Build a question-answering pipeline for the given display name.

    The result is cached by Streamlit (st.cache_resource), so each model is
    downloaded and instantiated at most once per server process.

    Args:
        model_name: a key of the module-level MODELS mapping.

    Returns:
        The transformers QA pipeline, or None when loading fails (the error
        is surfaced in the UI via st.error).
    """
    try:
        return pipeline("question-answering", model=MODELS[model_name])
    except Exception as e:
        st.error(f"Error loading model {model_name}: {e}")
        return None

def answer_question(qa_pipeline, question, context):
    """Answer a question given a context using the QA pipeline.

    Args:
        qa_pipeline: a callable accepting ``question=`` / ``context=`` keyword
            arguments and returning a dict with "answer" and "score" keys
            (e.g. a transformers question-answering pipeline).
        question: the question string.
        context: the passage to extract the answer from.

    Returns:
        A ``(answer, score, inference_time)`` tuple, where inference_time is
        in seconds; ``(None, 0, 0)`` when either input is empty/None.
    """
    if not question or not context:
        return None, 0, 0

    # perf_counter() is monotonic and high-resolution; time.time() is
    # wall-clock and can jump (NTP sync, clock changes), which could report
    # wrong or even negative latencies.
    start_time = time.perf_counter()

    # Run model
    result = qa_pipeline(question=question, context=context)

    inference_time = time.perf_counter() - start_time

    return result["answer"], result["score"], inference_time

def highlight_answer(context, answer):
    """Return *context* with the first occurrence of *answer* wrapped in a
    highlighting ``<span>`` (case-insensitive match, original casing kept).

    Args:
        context: the passage the answer was extracted from.
        answer: the answer string to highlight.

    Returns:
        The HTML-annotated context, or the context unchanged when either
        argument is falsy or the answer is not found.

    NOTE(review): the text is interpolated into HTML without escaping and the
    caller renders it with unsafe_allow_html=True, so HTML-special characters
    in the context/answer are interpreted as markup — confirm inputs are
    trusted or escape with html.escape().
    """
    if not answer or not context:
        return context

    # Single case-insensitive scan: find() returns -1 when absent (the
    # original did an `in` test followed by find(), scanning twice).
    start_idx = context.lower().find(answer.lower())
    if start_idx == -1:
        return context

    end_idx = start_idx + len(answer)
    return (
        context[:start_idx]
        + f'<span style="background-color: #ffdd99; font-weight: bold;">{context[start_idx:end_idx]}</span>'
        + context[end_idx:]
    )

def generate_comparison_chart(results_df):
    """Build a two-panel figure comparing the models' results.

    Args:
        results_df: DataFrame with 'model_name', 'score' and 'inference_time'
            columns (one row per model).

    Returns:
        A matplotlib Figure with two horizontal bar charts: confidence
        scores on the left, inference times on the right.
    """
    fig, (score_ax, time_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Highest-confidence model first, so both panels share the same order.
    ordered = results_df.sort_values('score', ascending=False)
    labels = ordered['model_name']

    # Left panel: confidence scores.
    score_ax.barh(labels, ordered['score'], color='skyblue')
    score_ax.set_xlabel('Confidence Score')
    score_ax.set_title('Model Confidence Scores')
    score_ax.grid(axis='x', linestyle='--', alpha=0.7)

    # Right panel: inference times (coerced to float for plotting).
    time_ax.barh(labels, ordered['inference_time'].astype(float), color='salmon')
    time_ax.set_xlabel('Inference Time (seconds)')
    time_ax.set_title('Model Inference Times')
    time_ax.grid(axis='x', linestyle='--', alpha=0.7)

    plt.tight_layout()
    return fig

def main():
    """Render the Streamlit UI: context/question inputs, single-model
    answering, and an all-model comparison table + chart.

    Streamlit re-executes this top-to-bottom on every interaction; widgets
    render in call order, so statement order here IS the page layout.
    """
    # Title and description
    st.title("Question Answering System")
    st.markdown("""
    This application answers questions based on the provided context using transformer-based models 
    fine-tuned on the SQuAD dataset. Enter a context paragraph and ask questions about it.
    """)
    
    # Initialize session state for storing results (survives reruns).
    if 'comparison_results' not in st.session_state:
        st.session_state.comparison_results = None
    
    # Layout: wide left column for inputs/results, narrow right for controls.
    col1, col2 = st.columns([3, 1])
    
    with col1:
        # Context input — default is a SQuAD-style passage about the Normans.
        context = st.text_area(
            "Context",
            "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
            height=200
        )
        
        # Question input
        question = st.text_input("Question", "In what country is Normandy located?")
        
        # Add a separator
        st.markdown("---")
        
        # Results section
        st.subheader("Results")
        
        if st.button("Compare All Models"):
            progress_bar = st.progress(0)
            results = []
            
            # Process each model in MODELS, collecting answer/score/latency.
            for i, model_name in enumerate(MODELS.keys()):
                status_text = st.empty()
                status_text.text(f"Processing with {model_name}...")
                
                # Load model (cached via st.cache_resource in load_model).
                qa_pipeline = load_model(model_name)
                if qa_pipeline is not None:
                    # Get answer
                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
                    
                    # Store results; a model that failed to load is simply
                    # omitted from the table.
                    results.append({
                        "model_name": model_name,
                        "answer": answer,
                        "score": score,
                        "inference_time": inference_time
                    })
                
                # Update progress
                progress_bar.progress((i + 1) / len(MODELS))
            
            # Display results in a table (formatted copy; raw floats kept in
            # results_df for charting).
            if results:
                results_df = pd.DataFrame(results)
                display_df = results_df.copy()
                display_df["inference_time"] = display_df["inference_time"].apply(lambda x: f"{x:.4f} s")
                display_df["score"] = display_df["score"].apply(lambda x: f"{x:.4f}")
                st.table(display_df)
                
                # Save results to session state for comparison chart
                st.session_state.comparison_results = results_df
                
                # Show comparison chart
                st.subheader("Model Comparison")
                comparison_chart = generate_comparison_chart(results_df)
                st.pyplot(comparison_chart)
    
    with col2:
        # Model selection
        st.subheader("Available Models")
        
        selected_model = st.selectbox(
            "Select a model",
            list(MODELS.keys()),
            key="model_selector"
        )
        
        # Load selected model and answer the current question with it.
        if st.button("Answer Question"):
            with st.spinner(f"Loading {selected_model}..."):
                qa_pipeline = load_model(selected_model)
            
            if qa_pipeline is not None:
                with st.spinner("Generating answer..."):
                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
                
                st.success("Answer generated!")
                st.markdown(f"**Model:** {selected_model}")
                st.markdown(f"**Answer:** {answer}")
                st.markdown(f"**Confidence:** {score:.4f}")
                st.markdown(f"**Inference Time:** {inference_time:.4f} seconds")
                
                # Highlight answer in context — rendered as raw HTML, so the
                # context must be trusted (see highlight_answer).
                st.subheader("Answer in Context")
                highlighted_context = highlight_answer(context, answer)
                st.markdown(highlighted_context, unsafe_allow_html=True)
        
        # Advanced options
        with st.expander("Model Information"):
            st.markdown("""
            **ELECTRA-small**  
            A smaller, efficient model with good performance and speed.
            
            **ALBERT-base-v2**  
            Parameter-efficient model with strong performance.
            
            **DistilBERT-base**  
            Distilled BERT model that's faster while maintaining accuracy.
            """)

# Script entry point (Streamlit executes the file top-to-bottom on each run).
if __name__ == "__main__":
    main()