# File size: 4,792 Bytes
# d80d74f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import numpy as np
import torch
from transformers import BertForQuestionAnswering, BertTokenizerFast

# ── Configuration & model loading ────────────────────────────
MODEL_DIR = "model"   # passed to from_pretrained() for both tokenizer and model
MAX_LENGTH = 384      # max tokens per tokenized window (question + context + specials)
DOC_STRIDE = 128      # `stride` for the tokenizer: overlap between context windows
N_BEST = 20           # how many top start/end logits are paired per window
MAX_ANS_LEN = 30      # longest allowed answer span, measured in tokens
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR)
model.to(DEVICE)  # nn.Module.to() moves parameters in place
model.eval()      # inference behaviour (e.g. dropout disabled)
print(f"✅ Model loaded on {DEVICE}")

def _span_candidates(start_row, end_row, offsets, seq_ids, context, n_best, max_answer_len):
    """Yield candidate answer dicts for a single tokenized context window.

    Pairs each of the ``n_best`` highest start logits with each of the
    ``n_best`` highest end logits (both visited in descending order), keeping
    only pairs whose tokens both belong to the context (sequence id 1), whose
    end is not before the start, and whose length is at most
    ``max_answer_len`` tokens.
    """
    # argsort is ascending; this slice walks the last n_best entries backwards,
    # i.e. the top-n_best indices from highest to lowest logit.
    top_starts = np.argsort(start_row)[-1:-n_best - 1:-1]
    top_ends = np.argsort(end_row)[-1:-n_best - 1:-1]
    for s in top_starts:
        # Sequence id 1 is the context; 0 is the question, None marks
        # special/padding tokens. A non-context start rules out every pair.
        if seq_ids[s] != 1:
            continue
        for e in top_ends:
            if seq_ids[e] != 1 or e < s or e - s + 1 > max_answer_len:
                continue
            start_char = int(offsets[s][0])
            end_char = int(offsets[e][1])
            yield {
                "score": float(start_row[s] + end_row[e]),
                "text": context[start_char:end_char],
                "start": start_char,
                "end": end_char,
            }


def answer_question(
    question: str,
    context: str,
    *,
    n_best: "int | None" = None,
    max_answer_len: "int | None" = None,
) -> dict:
    """Extract the highest-scoring answer span for *question* from *context*.

    The context is tokenized into windows of at most ``MAX_LENGTH`` tokens
    using ``stride=DOC_STRIDE`` with overflowing tokens enabled. Only the
    context is truncated. All windows are scored in one forward pass. Within
    each window, the top start and end positions are paired, and the pair
    with the largest ``start_logit + end_logit`` over all windows wins.

    Args:
        question: Question text (first sequence).
        context: Passage to search (second sequence).
        n_best: Top start/end positions to pair per window; ``None`` means
            the module-level ``N_BEST`` (read at call time).
        max_answer_len: Maximum answer length in tokens; ``None`` means the
            module-level ``MAX_ANS_LEN`` (read at call time).

    Returns:
        ``{"answer", "score", "start", "end"}``. The score is the logit sum
        rounded to 4 places, and start/end are character offsets into
        *context*. When no valid span exists, the result is
        ``{"answer": "No answer found.", "score": -999, "start": -1, "end": -1}``.
    """
    n_best = N_BEST if n_best is None else n_best
    max_answer_len = MAX_ANS_LEN if max_answer_len is None else max_answer_len

    inputs = tokenizer(
        question,
        context,
        max_length=MAX_LENGTH,
        truncation="only_second",
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Both keys are tokenizer bookkeeping, not model arguments, so they must
    # be removed before calling the model. With a single context the
    # window->sample mapping is not needed, so its value is discarded.
    offset_mapping = inputs.pop("offset_mapping")
    inputs.pop("overflow_to_sample_mapping")
    sequence_ids = [inputs.sequence_ids(i) for i in range(len(inputs["input_ids"]))]

    model_inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**model_inputs)

    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()

    candidates = []
    for start_row, end_row, offsets, seq_ids in zip(
        start_logits, end_logits, offset_mapping.numpy(), sequence_ids
    ):
        candidates.extend(
            _span_candidates(
                start_row, end_row, offsets, seq_ids, context, n_best, max_answer_len
            )
        )

    if not candidates:
        return {"answer": "No answer found.", "score": -999, "start": -1, "end": -1}

    best = max(candidates, key=lambda cand: cand["score"])
    return {
        "answer": best["text"],
        "score": round(best["score"], 4),
        "start": best["start"],
        "end": best["end"],
    }


def ask(question: str, context: str) -> dict:
    """Answer *question* against *context*, pretty-print it, and return it.

    Returns:
        The dict produced by ``answer_question`` (keys ``answer``, ``score``,
        ``start``, ``end``), for callers that want the result programmatically.
    """
    result = answer_question(question, context)
    print(f"❓ Question: {question}")
    print(f"💬 Answer  : {result['answer']}")
    print(f"📊 Score   : {result['score']}")
    # Separator between the offsets: without it, start=12 and end=345
    # would be printed as the single misleading number "12345".
    print(f"📍 Position: Char {result['start']}–{result['end']}")
    print("-" * 60)
    return result



# ── Demo: fixed question/context pairs ───────────────────────
ctx1 = """
The Amazon rainforest, also known as Amazonia, is a moist broadleaf 
tropical rainforest in the Amazon biome that covers most of the Amazon 
basin of South America. This basin encompasses 7,000,000 km² of which 
5,500,000 km² are covered by the rainforest. The majority of the forest 
is contained within Brazil, with 60% of the rainforest.
"""

ctx2 = """
The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars 
in Paris, France. It was constructed from 1887 to 1889 as the centerpiece 
of the 1889 World's Fair. The tower is 330 metres tall and is the tallest 
structure in Paris.
"""

# Repeated three times, presumably so the text spans more than one
# MAX_LENGTH window and the stride/overlap path is exercised — verify.
ctx3 = """
Python is a high-level, general-purpose programming language. Its design 
philosophy emphasizes code readability with the use of significant indentation.
Python is dynamically typed and garbage-collected. It supports multiple 
programming paradigms, including structured, object-oriented and functional 
programming. It was created by Guido van Rossum and first released in 1991.
Python consistently ranks as one of the most popular programming languages.
It is widely used in data science, machine learning, web development, and 
automation. The Python Package Index (PyPI) hosts hundreds of thousands of 
third-party modules. The standard library is very extensive, offering tools 
suited to many tasks.
""" * 3

for demo_question, demo_context in (
    ("How much of the Amazon rainforest is in Brazil?", ctx1),
    ("When was the Eiffel Tower built?", ctx2),
    ("When was Python first released?", ctx3),
):
    ask(demo_question, demo_context)

# ── Interactive mode ─────────────────────────────────────────
print("\n" + "=" * 60)
print("🎮 Interactive mode – stop with 'quit'")
print("=" * 60)

# input() raises EOFError when stdin is closed (Ctrl-D, or piped input runs
# out) and KeyboardInterrupt on Ctrl-C. Treat both like 'quit' rather than
# crashing with a traceback. Inference inside ask() is covered too, so a long
# run can be interrupted cleanly.
try:
    context_interactive = input("📄 Input context:\n> ").strip()
    # An empty context can never contain an answer, so ask again instead of
    # answering every question with "No answer found.".
    while not context_interactive:
        print("⚠️ Context is empty – please enter some text.")
        context_interactive = input("📄 Input context:\n> ").strip()

    while True:
        q = input("\n❓ Question (or type 'quit'): ").strip()
        if q.lower() == "quit":
            print("👋 Bye.")
            break
        if not q:
            continue
        ask(q, context_interactive)
except (EOFError, KeyboardInterrupt):
    print("\n👋 Bye.")