File size: 2,118 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35cedd1
 
 
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e6ba65
461f64f
 
 
6c5cdb8
b8004c8
 
461f64f
 
 
5ffe74e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Question Answering System trained on SQuAD 2.0
"""

import gradio as gr
import sys
from pathlib import Path

# Put this script's own directory at the front of sys.path so the local
# 'src' package can be imported regardless of the current working directory.
# resolve() makes the entry absolute, so later chdir() calls cannot break it.
current_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(current_dir))

from src.models.bert_based_model import BertBasedQAModel
from src.config.model_configs import OriginalBertQAConfig
from src.etl.types import QAExample

# Loaded once at import time and shared by every request; inference runs on CPU.
# The checkpoint directory is resolved relative to this script (the same
# directory used for the 'src' import), not the process working directory, so
# the app can be launched from anywhere.
model = BertBasedQAModel.load_from_experiment(
    experiment_dir=current_dir / "checkpoint",
    config_class=OriginalBertQAConfig,
    device="cpu",
)


def answer_question(context: str, question: str) -> str:
    """Answer ``question`` from the text in ``context``.

    Gradio callback for the demo UI. It validates both inputs, wraps them in a
    single ``QAExample`` and runs the module-level ``model`` on it.

    Args:
        context: Paragraph the answer should be extracted from. ``None`` is
            treated like an empty string.
        question: Question to ask about ``context``. ``None`` is treated like
            an empty string.

    Returns:
        The predicted answer text. Otherwise one of:
        "No answer found." when the model returns an empty answer;
        a prompt asking for the missing input;
        an ``"Error: ..."`` message when building the example or running
        inference raises.
    """
    # Gradio may hand over None (e.g. via API calls) instead of "", so
    # normalise before validating; strip once and reuse the result.
    context = (context or "").strip()
    question = (question or "").strip()
    if not context:
        return "Please provide context text."
    if not question:
        return "Please provide a question."

    try:
        example = QAExample(
            question_id="demo",
            title="Demo",
            question=question,
            context=context,
            answer_texts=[],
            answer_starts=[],
            # TODO - treat this more systematically accounting for inference;
            # setting is_impossible to True since no ground truth is available for an unknown Q
            is_impossible=True,
        )
        predictions = model.predict({"demo": example})
        answer = predictions["demo"].predicted_answer
    except Exception as e:
        # UI boundary: show the failure to the user instead of failing the
        # request, but keep the full traceback in the server log.
        import logging

        logging.getLogger(__name__).exception("QA inference failed")
        return f"Error: {e}"

    return answer or "No answer found."


# Single-page UI: a multi-line context box and a question box as inputs, and
# one copyable answer box as output, all wired to answer_question.
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(lines=8, placeholder="Enter context paragraph...", label="Context"),
        gr.Textbox(placeholder="Enter your question...", label="Question"),
    ],
    outputs=gr.Textbox(label="Answer", show_copy_button=True, lines=4),
    title="SQuAD 2.0 Question Answering",
    description="BERT-base model fine-tuned on SQuAD 2.0 dataset",
    # Gradio 5 replaced the deprecated ``allow_flagging`` with ``flagging_mode``.
    flagging_mode="never",
    deep_link=False,  # hides the "Share via Link" button
    # Theme referenced by name; "owner/name" strings point to a Gradio Hub theme.
    theme="earneleh/paris",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()