import numpy as np
import torch
from transformers import BertForQuestionAnswering, BertTokenizerFast
# ── Config ───────────────────────────────────────────────────
MODEL_DIR = "model"        # local directory holding the fine-tuned checkpoint + tokenizer files
MAX_LENGTH = 384           # max tokens per chunk (question + context slice)
DOC_STRIDE = 128           # token overlap between consecutive context chunks
N_BEST = 20                # top-k start/end logits considered per chunk
MAX_ANS_LEN = 30           # longest accepted answer span, in tokens
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load tokenizer and QA head once at import time; model stays in eval mode
# (no dropout) since this script only does inference.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()
print(f"✅ Model loaded on {DEVICE}")
def answer_question(question: str, context: str) -> dict:
    """Extract the best answer span for *question* from *context*.

    Long contexts are split into overlapping chunks (``DOC_STRIDE`` token
    overlap, ``MAX_LENGTH`` tokens each, question never truncated). The top
    ``N_BEST`` start/end logits of every chunk are paired, invalid spans are
    filtered out, and the highest-scoring remaining span is returned.

    Returns:
        dict with keys ``answer`` (str), ``score`` (float, sum of start and
        end logits rounded to 4 places), and ``start``/``end`` (character
        offsets into *context*). When no valid span exists the sentinel
        ``{"answer": "No answer found.", "score": -999, "start": -1, "end": -1}``
        is returned.
    """
    inputs = tokenizer(
        question,
        context,
        max_length=MAX_LENGTH,
        truncation="only_second",  # only the context may be truncated/chunked
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )
    offset_mapping = inputs.pop("offset_mapping")  # (n_chunks, seq_len, 2)
    # The model forward() does not accept this key; discard it unbound — with
    # a single (question, context) pair the chunk→sample mapping is unused.
    inputs.pop("overflow_to_sample_mapping")
    # Must be read while `inputs` is still a BatchEncoding (plain dicts built
    # below have no sequence_ids()).
    sequence_ids = [inputs.sequence_ids(i) for i in range(len(inputs["input_ids"]))]
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    start_logits = outputs.start_logits.cpu().numpy()  # (n_chunks, seq_len)
    end_logits = outputs.end_logits.cpu().numpy()

    candidates = []
    for chunk_idx in range(len(start_logits)):
        offsets = offset_mapping[chunk_idx].numpy()
        seq_ids = sequence_ids[chunk_idx]
        # Hoist the per-chunk rows out of the O(N_BEST^2) pairing loop.
        chunk_start = start_logits[chunk_idx]
        chunk_end = end_logits[chunk_idx]
        # Indices of the N_BEST largest logits, best first.
        s_indexes = np.argsort(chunk_start)[-1:-N_BEST - 1:-1]
        e_indexes = np.argsort(chunk_end)[-1:-N_BEST - 1:-1]
        for s in s_indexes:
            for e in e_indexes:
                # Both endpoints must lie in the context (sequence id 1) —
                # this rejects question, special, and padding positions.
                if seq_ids[s] != 1 or seq_ids[e] != 1:
                    continue
                # Reject inverted spans and spans longer than MAX_ANS_LEN tokens.
                if e < s or e - s + 1 > MAX_ANS_LEN:
                    continue
                candidates.append({
                    "score": float(chunk_start[s] + chunk_end[e]),
                    "text": context[offsets[s][0]: offsets[e][1]],
                    "start": int(offsets[s][0]),
                    "end": int(offsets[e][1]),
                })

    if not candidates:
        return {"answer": "No answer found.", "score": -999, "start": -1, "end": -1}
    best = max(candidates, key=lambda x: x["score"])
    return {
        "answer": best["text"],
        "score": round(best["score"], 4),
        "start": best["start"],
        "end": best["end"],
    }
def ask(question: str, context: str):
    """Answer *question* against *context* and pretty-print the result."""
    result = answer_question(question, context)
    report = "\n".join([
        f"❓ Question: {question}",
        f"💬 Answer : {result['answer']}",
        f"📊 Score : {result['score']}",
        f"📍 Position: Char {result['start']}–{result['end']}",
        "-" * 60,
    ])
    print(report)
# ── Demo 1: short factual context ────────────────────────────
ctx1 = """
The Amazon rainforest, also known as Amazonia, is a moist broadleaf
tropical rainforest in the Amazon biome that covers most of the Amazon
basin of South America. This basin encompasses 7,000,000 km² of which
5,500,000 km² are covered by the rainforest. The majority of the forest
is contained within Brazil, with 60% of the rainforest.
"""
ask("How much of the Amazon rainforest is in Brazil?", ctx1)
# ── Demo 2: date-extraction question ─────────────────────────
ctx2 = """
The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars
in Paris, France. It was constructed from 1887 to 1889 as the centerpiece
of the 1889 World's Fair. The tower is 330 metres tall and is the tallest
structure in Paris.
"""
ask("When was the Eiffel Tower built?", ctx2)
# ── Demo 3: long context ─────────────────────────────────────
ctx3 = """
Python is a high-level, general-purpose programming language. Its design
philosophy emphasizes code readability with the use of significant indentation.
Python is dynamically typed and garbage-collected. It supports multiple
programming paradigms, including structured, object-oriented and functional
programming. It was created by Guido van Rossum and first released in 1991.
Python consistently ranks as one of the most popular programming languages.
It is widely used in data science, machine learning, web development, and
automation. The Python Package Index (PyPI) hosts hundreds of thousands of
third-party modules. The standard library is very extensive, offering tools
suited to many tasks.
""" * 3
# Tripled text exceeds MAX_LENGTH tokens, exercising the overlapping-chunk
# path in answer_question (return_overflowing_tokens + DOC_STRIDE).
ask("When was Python first released?", ctx3)
# ── Interactive mode: one context, many questions ────────────
print("\n" + "=" * 60)
print("🎮 Interactive mode – stop with 'quit'")
print("=" * 60)
context_interactive = input("📄 Input context:\n> ").strip()
while True:
    query = input("\n❓ Question (or type 'quit'): ").strip()
    if query.lower() == "quit":
        print("👋 Bye.")
        break
    # Silently re-prompt on blank input; otherwise run the QA pipeline.
    if query:
        ask(query, context_interactive)