"""Interactive extractive question answering over a document file.

Loads a saved ``QAModel`` checkpoint at import time, reads a context document
(.txt / .pdf / .docx), and prints the token span the model selects as the
answer to a user-supplied question.
"""

import torch

from utils.file_loader import load_txt, load_pdf, load_docx
from models.qa_model import QAModel
from utils.vocab import encode
from utils.preprocess import tokenize

# NOTE(security): torch.load unpickles the checkpoint, so "qa_model.pth" must
# come from a trusted source. Consider weights_only=True if the stored vocab
# is a plain container -- TODO confirm before changing.
checkpoint = torch.load("qa_model.pth", map_location="cpu")
vocab = checkpoint["vocab"]

# The model is sized from the vocabulary stored alongside the weights.
model = QAModel(len(vocab))
model.load_state_dict(checkpoint["model_state"])
model.eval()


def load_context(path: str):
    """Load the document at *path* using the loader that matches its extension.

    The extension is matched case-insensitively, so ``notes.TXT`` and
    ``report.PDF`` are accepted just like their lower-case forms.

    Args:
        path: Path to a ``.txt``, ``.pdf`` or ``.docx`` file.

    Returns:
        Whatever the matching loader (``load_txt`` / ``load_pdf`` /
        ``load_docx``) returns -- presumably the document text; verify
        against ``utils.file_loader``.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    # Compare on a lower-cased copy only; the loader receives the original
    # path so case-sensitive filesystems still find the file.
    lowered = path.lower()
    if lowered.endswith(".txt"):
        return load_txt(path)
    if lowered.endswith(".pdf"):
        return load_pdf(path)
    if lowered.endswith(".docx"):
        return load_docx(path)
    raise ValueError("Unsupported file format")


def extract_answer(question, context) -> str:
    """Return the span of tokens the model selects as the answer.

    The question and context are tokenized, joined as
    ``question + ["[SEP]"] + context``, encoded with the checkpoint vocabulary,
    and padded or truncated to a fixed length of 300 ids before being fed to
    the model as a batch of one.

    Args:
        question: Question text, passed to ``tokenize``.
        context: Context text, passed to ``tokenize``.

    Returns:
        The selected tokens joined by single spaces, or the string
        ``"No answer found"`` when the predicted start lies after the
        predicted end or beyond the real (unpadded) tokens.
    """
    q_tokens = tokenize(question)
    c_tokens = tokenize(context)
    tokens = q_tokens + ["[SEP]"] + c_tokens

    encoded = encode(tokens, vocab)

    max_len = 300
    shortfall = max_len - len(encoded)
    if shortfall > 0:
        # Build a new list rather than extending in place, so the object
        # returned by encode() is never mutated.
        # assumes id 0 is the padding index -- TODO confirm against utils.vocab
        encoded = encoded + [0] * shortfall
    else:
        encoded = encoded[:max_len]

    # Add a leading batch dimension of size 1.
    x = torch.tensor(encoded).unsqueeze(0)

    with torch.no_grad():
        start_logits, end_logits = model(x)

    # presumably logits are (batch, seq_len); argmax over dim=1 then yields a
    # single position for the one-item batch -- verify against QAModel.
    start = torch.argmax(start_logits, dim=1).item()
    end = torch.argmax(end_logits, dim=1).item()

    # Reject inverted spans and spans that start in the padding region.
    if start > end or start >= len(tokens):
        return "No answer found"

    # An end index that falls in the padding is harmless: slicing stops at
    # the last real token.
    return " ".join(tokens[start:end + 1])


def main():
    """Prompt for a file path and a question, then print the extracted answer."""
    print("===== BiLSTM QA (Fixed) =====\n")

    path = input("Enter file path: ")
    context = load_context(path)

    question = input("Enter question: ")
    answer = extract_answer(question, context)

    print("\nAnswer:", answer)


if __name__ == "__main__":
    main()