| import torch |
| from utils.file_loader import load_txt, load_pdf, load_docx |
| from models.qa_model import QAModel |
| from utils.vocab import encode |
| from utils.preprocess import tokenize |
|
|
|
|
# Restore the trained checkpoint once, at import time, mapping all tensors to
# the CPU. The checkpoint is read as a mapping that provides a "vocab" entry
# and a "model_state" entry.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. If the checkpoint holds nothing but tensors and plain
# containers, passing weights_only=True would be safer -- confirm before
# changing.
checkpoint = torch.load("qa_model.pth", map_location="cpu")
vocab = checkpoint["vocab"]

# The model is constructed from the vocabulary size (presumably to size its
# embedding table -- confirm in QAModel), then given the saved parameters.
model = QAModel(len(vocab))
model.load_state_dict(checkpoint["model_state"])
# Inference only: eval() is expected to disable training-time behaviour such
# as dropout (assuming QAModel is a torch.nn.Module).
model.eval()
|
|
|
|
def load_context(path):
    """Load the document at *path* using the loader that matches its extension.

    The extension is matched case-insensitively, so ``report.PDF`` and
    ``notes.Txt`` are accepted as well as their lower-case forms.

    Args:
        path: File path as a string, ending in ``.txt``, ``.pdf`` or ``.docx``.

    Returns:
        Whatever the matching loader (``load_txt``, ``load_pdf`` or
        ``load_docx``) returns for *path*.

    Raises:
        ValueError: If the path ends in none of the supported extensions.
    """
    # Compare against a lowercased copy only; the loader still receives the
    # path exactly as the caller spelled it (file systems may be
    # case-sensitive).
    lowered = path.lower()
    if lowered.endswith(".txt"):
        return load_txt(path)
    if lowered.endswith(".pdf"):
        return load_pdf(path)
    if lowered.endswith(".docx"):
        return load_docx(path)
    raise ValueError("Unsupported file format")
|
|
|
|
def _pick_context_span(start_logits, end_logits, lo, hi):
    """Return the predicted ``(start, end)`` indices inside ``[lo, hi)``, or None.

    Both logit tensors are flattened, which assumes they describe a single
    sequence (a batch of one). The start and end positions are each the
    highest-scoring index within the window; ``None`` is returned when the
    window is empty or the best start lies after the best end.
    """
    start_scores = start_logits.reshape(-1)
    end_scores = end_logits.reshape(-1)
    # Never index past what the model actually produced.
    hi = min(hi, start_scores.numel(), end_scores.numel())
    if lo >= hi:
        return None
    start = lo + int(torch.argmax(start_scores[lo:hi]))
    end = lo + int(torch.argmax(end_scores[lo:hi]))
    if start > end:
        return None
    return start, end


def extract_answer(question, context, max_len=300):
    """Return the answer span the model predicts for *question* in *context*.

    The question and context are tokenized, joined as
    ``question + ["[SEP]"] + context``, encoded with the checkpoint
    vocabulary, padded with id 0 (presumably the padding id -- confirm against
    the vocabulary) or truncated to ``max_len`` ids, and run through the model
    as a batch of one.

    Only context tokens that fit inside the model window are eligible as span
    boundaries, so the answer can never contain question tokens, the
    ``[SEP]`` marker, or padding positions.

    Args:
        question: Question text.
        context: Document text to search for the answer.
        max_len: Number of token ids fed to the model. Defaults to 300.

    Returns:
        The predicted tokens joined by single spaces, or
        ``"No answer found"`` when the question or context is blank, no
        context token fits in the window, or the predicted start lies after
        the predicted end.
    """
    no_answer = "No answer found"

    # Blank input cannot yield a span; bail out before touching the model.
    if not (question and question.strip()):
        return no_answer
    if not (context and context.strip()):
        return no_answer

    q_tokens = tokenize(question)
    c_tokens = tokenize(context)
    tokens = q_tokens + ["[SEP]"] + c_tokens

    # Context occupies [ctx_begin, ctx_end): everything after the separator
    # that survives truncation to max_len.
    ctx_begin = len(q_tokens) + 1
    ctx_end = min(len(tokens), max_len)
    if ctx_begin >= ctx_end:
        return no_answer

    # Copy before padding so the list returned by encode() is never mutated.
    encoded = list(encode(tokens, vocab))[:max_len]
    encoded += [0] * (max_len - len(encoded))

    x = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        start_logits, end_logits = model(x)

    span = _pick_context_span(start_logits, end_logits, ctx_begin, ctx_end)
    if span is None:
        return no_answer

    start, end = span
    return " ".join(tokens[start:end + 1])
|
|
|
|
def main():
    """Run the interactive prompt: ask for a file and a question, print the answer.

    Reads a document path and a question from standard input, loads the
    document with ``load_context``, and prints the span returned by
    ``extract_answer``.

    Raises:
        SystemExit: With a short error message (non-zero exit status) when the
            file format is unsupported or the file cannot be read.
    """
    print("===== BiLSTM QA (Fixed) =====\n")

    # Pasted or drag-and-dropped paths often carry stray whitespace or
    # surrounding quotes, which would defeat the extension check.
    path = input("Enter file path: ").strip().strip("'\"")
    try:
        context = load_context(path)
    except (ValueError, OSError) as exc:
        # Report the problem without a traceback, but still exit non-zero.
        raise SystemExit(f"Error: {exc}") from exc

    question = input("Enter question: ").strip()

    answer = extract_answer(question, context)

    print("\nAnswer:", answer)
|
|
|
|
# Start the interactive prompt only when this file is executed as a script,
# not when it is imported by another module.
if __name__ == "__main__":
    main()
|
|