SQuAD / main.py
tnp554's picture
feat: deploy SQuAD backend with all AI models
09daf0b
import os
import sys

import torch

from utils.file_loader import load_txt, load_pdf, load_docx
from models.qa_model import QAModel
from utils.vocab import encode
from utils.preprocess import tokenize
# Checkpoint file name. It is looked up in the current working directory
# first (the historical behaviour) and otherwise next to this script, so the
# program also works when launched from a different directory.
_CHECKPOINT_NAME = "qa_model.pth"
_checkpoint_path = (
    _CHECKPOINT_NAME
    if os.path.exists(_CHECKPOINT_NAME)
    else os.path.join(os.path.dirname(os.path.abspath(__file__)), _CHECKPOINT_NAME)
)

# Load once at import time; map_location="cpu" places all stored tensors on
# the CPU so no GPU is required.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(_checkpoint_path, map_location="cpu")
vocab = checkpoint["vocab"]
# The model is sized from the vocabulary stored alongside its weights.
model = QAModel(len(vocab))
model.load_state_dict(checkpoint["model_state"])
# Inference only: turn off training-time behaviour (e.g. dropout, if QAModel
# uses any).
model.eval()
def load_context(path):
    """Read a document with the loader that matches its file extension.

    The extension is compared case-insensitively, so ``report.PDF`` is
    handled exactly like ``report.pdf``.

    Args:
        path: Path to a ``.txt``, ``.pdf`` or ``.docx`` file, as a ``str``
            or any ``os.PathLike``.

    Returns:
        Whatever the selected ``load_txt`` / ``load_pdf`` / ``load_docx``
        helper returns; this script uses it as the context text.

    Raises:
        ValueError: if the file extension is not one of the supported ones.
    """
    path = os.fspath(path)
    # Lower-case only for the comparison; the loader gets the path unchanged.
    lowered = path.lower()
    if lowered.endswith(".txt"):
        return load_txt(path)
    if lowered.endswith(".pdf"):
        return load_pdf(path)
    if lowered.endswith(".docx"):
        return load_docx(path)
    raise ValueError("Unsupported file format")
def extract_answer(question, context, max_len=300):
    """Predict the answer to *question* as a span of *context* tokens.

    The question and context are tokenized and joined as
    ``question_tokens + ["[SEP]"] + context_tokens``. That sequence is
    encoded with the module-level ``vocab``, then cut or right-padded with
    id 0 to exactly ``max_len`` ids and fed to the module-level ``model``.
    The highest-scoring start and end positions are taken independently.

    Args:
        question: Question text.
        context: Passage text to search for the answer.
        max_len: Fixed input length given to the model (default 300, the
            value this script has always used). Tokens beyond it are
            dropped, so an answer located there cannot be found.

    Returns:
        The predicted context tokens joined by single spaces, or
        ``"No answer found"`` when the predicted span is inverted, does not
        start inside the context (i.e. lands on the question or the
        separator), or starts past the last real token.

    Raises:
        ValueError: if ``max_len`` is not a positive integer.
    """
    if max_len <= 0:
        raise ValueError("max_len must be a positive integer")

    q_tokens = tokenize(question)
    c_tokens = tokenize(context)
    tokens = q_tokens + ["[SEP]"] + c_tokens
    # Position of the first context token in the combined sequence; anything
    # before it is question text or the separator, never a valid answer.
    context_start = len(q_tokens) + 1

    # Slice first so the list returned by encode() is not mutated in place.
    encoded = encode(tokens, vocab)[:max_len]
    # Right-pad with id 0. NOTE(review): assumes 0 is the padding id the
    # model was trained with -- confirm against utils.vocab.
    encoded = encoded + [0] * (max_len - len(encoded))
    x = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)  # batch of one

    with torch.no_grad():
        start_logits, end_logits = model(x)
    start = torch.argmax(start_logits, dim=1).item()
    end = torch.argmax(end_logits, dim=1).item()

    if start > end or not context_start <= start < len(tokens):
        return "No answer found"
    # Slicing clips an end that falls in the padding region.
    return " ".join(tokens[start:end + 1])
def main():
    """Interactive CLI: ask for a document and a question, print the answer.

    Returns:
        Process exit status: 0 on success, 1 if the document could not be
        loaded (unsupported extension, missing or unreadable file).
    """
    print("===== BiLSTM QA (Fixed) =====\n")
    # Terminals often add surrounding whitespace or quotes when a path is
    # pasted or drag-and-dropped; remove them before opening the file.
    path = input("Enter file path: ").strip().strip("\"'")
    try:
        context = load_context(path)
    except (ValueError, OSError) as err:
        print(f"Error: could not load {path!r}: {err}", file=sys.stderr)
        return 1
    question = input("Enter question: ").strip()
    answer = extract_answer(question, context)
    print("\nAnswer:", answer)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status as the process exit code.
    raise SystemExit(main())