# QA_System / app.py
# Author: swathibp (uploaded as app.py, commit 36de92b)
import os
import faiss
import pickle
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
import gradio as gr
# ==================================================
# CONFIG
# ==================================================
# Runtime settings: model repos, artifact directory, retrieval depth,
# answer length cap and the torch device everything runs on.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

CONFIG = {
    "retriever_model_path": "swathibp/BGE-base_finetuned",
    "generator_model_path": "swathibp/Flan_T5_merged",
    "save_dir": ".",        # directory holding index.faiss / docs.pkl
    "top_k": 3,             # documents retrieved per question
    "max_new_tokens": 250,  # cap on generated answer length
    "device": _DEVICE,
}

# Make sure the artifact directory exists (a no-op for the default ".").
os.makedirs(CONFIG["save_dir"], exist_ok=True)
print("DEVICE:", CONFIG["device"])
# ==================================================
# LOAD FAISS
# ==================================================
# The FAISS index and the pickled document store must already exist in
# CONFIG["save_dir"]; this script only loads them, it does not build them.
INDEX_FILE = os.path.join(CONFIG["save_dir"], "index.faiss")
DOC_FILE = os.path.join(CONFIG["save_dir"], "docs.pkl")

# Fail fast: without these files `index` / `documents` would never be bound
# and every query would later crash with a NameError inside retrieve().
for _artifact in (INDEX_FILE, DOC_FILE):
    if not os.path.exists(_artifact):
        raise FileNotFoundError(
            f"Required retrieval artifact not found: {_artifact}. "
            "Build the FAISS index and document store before launching the app."
        )

print("Loading Retriever...")
retriever = SentenceTransformer(CONFIG["retriever_model_path"])

print("Loading Stored FAISS Index")
index = faiss.read_index(INDEX_FILE)

# SECURITY: pickle.load can execute arbitrary code while unpickling, so
# DOC_FILE must come from a trusted source (the same place as index.faiss).
with open(DOC_FILE, "rb") as f:
    documents = pickle.load(f)

# retrieve() uses every id returned by index.search as a position in
# `documents`, so the store must cover all vectors in the index.
if len(documents) < index.ntotal:
    raise ValueError(
        f"{DOC_FILE} holds {len(documents)} documents but {INDEX_FILE} "
        f"holds {index.ntotal} vectors; the two artifacts are out of sync."
    )
# ==================================================
# LOAD GENERATOR
# ==================================================
# Tokenizer and seq2seq model come from the same repo; the model is moved to
# the configured device and switched to inference mode.
print("Loading FLAN Generator...")
_generator_repo = CONFIG["generator_model_path"]

tokenizer = AutoTokenizer.from_pretrained(_generator_repo)
generator = AutoModelForSeq2SeqLM.from_pretrained(_generator_repo).to(CONFIG["device"])
generator.eval()

print("Generator Loaded")
# ==================================================
# RETRIEVAL
# ==================================================
def retrieve(query):
    """Return the stored documents nearest to *query* in the FAISS index.

    Args:
        query: The user's question text.

    Returns:
        A list of at most CONFIG["top_k"] entries from ``documents``, ordered
        as FAISS ranked them. It may be shorter than top_k when the index
        holds fewer vectors.
    """
    emb = retriever.encode([query], convert_to_numpy=True)
    # Normalise in place so the query matches how the stored vectors are
    # presumably normalised (cosine similarity via inner product -- confirm
    # against the index build script).
    faiss.normalize_L2(emb)
    _, indices = index.search(emb, CONFIG["top_k"])
    # FAISS pads missing results with -1; documents[-1] would silently return
    # the last document, so drop those placeholder ids.
    return [documents[idx] for idx in indices[0] if idx >= 0]
# ==================================================
# GENERATION
# ==================================================
# Tokens held back from the context budget. Re-tokenising a decoded, trimmed
# context (and joining it with the surrounding text) can shift the count
# slightly.
_PROMPT_TOKEN_MARGIN = 8


def _build_prompt(instruction, context, query):
    """Assemble the instruction / context / question prompt for the generator."""
    return (
        f"\n{instruction}\n"
        f"Context:\n{context}\n"
        f"Question:\n{query}\n"
        "Answer:\n"
    )


def _fit_context(instruction, context, query):
    """Trim *context* so the full prompt fits the tokenizer's max length.

    ``tokenizer(..., truncation=True)`` removes tokens from the right end of
    the prompt by default. That would silently drop the question and the
    "Answer:" cue whenever the retrieved context is long, so only the context
    is shortened here.
    """
    max_len = tokenizer.model_max_length
    # Tokenizers without a real limit report a huge sentinel; nothing to trim.
    if not max_len or max_len > 1_000_000:
        return context
    # Tokens taken by everything except the context (incl. special tokens).
    overhead = len(tokenizer(_build_prompt(instruction, "", query))["input_ids"])
    budget = max_len - overhead - _PROMPT_TOKEN_MARGIN
    if budget <= 0:
        return ""
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) <= budget:
        return context
    return tokenizer.decode(context_ids[:budget], skip_special_tokens=True)


def generate(query):
    """Answer *query* using only the top-k retrieved documents.

    Args:
        query: The user's question text.

    Returns:
        A tuple ``(answer, context)``. ``answer`` is the decoded generator
        output. ``context`` is the full newline-joined retrieved text (shown
        in the UI), even if only a prefix of it fit into the prompt.
    """
    docs = retrieve(query)
    instruction = (
        "Answer ONLY using the information provided in the context. "
        "If the answer is not available, reply exactly: "
        "'Not found in the provided documents.'"
    )
    context = "\n".join(docs)

    prompt = _build_prompt(instruction, _fit_context(instruction, context, query), query)
    # truncation=True is kept only as a last-resort safety net; the context
    # has already been trimmed so the question survives.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(CONFIG["device"])

    with torch.no_grad():
        outputs = generator.generate(
            **inputs,
            max_new_tokens=CONFIG["max_new_tokens"],
            do_sample=False,
            early_stopping=True,
        )

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer, context
# ==================================================
# UI
# ==================================================
# Single-page Gradio app: a question box, a button wired to generate(), and
# two read-only style outputs for the answer and the retrieved context.
with gr.Blocks() as demo:
    gr.Markdown(value="# MAHE QA System")

    q = gr.Textbox(
        label="Question",
        placeholder="Enter your MAHE question here...",
        lines=3,
        max_lines=5,
    )
    ask = gr.Button(value="Generate Answer")

    ans = gr.Textbox(label="Answer", lines=15, max_lines=30)
    ctx = gr.Textbox(label="Retrieved Context", lines=20, max_lines=40)

    # generate() returns (answer, context); they fill ans and ctx in order.
    ask.click(fn=generate, inputs=q, outputs=[ans, ctx])

demo.launch(share=True, debug=True)