| import os
|
| import faiss
|
| import pickle
|
| import torch
|
|
|
| from sentence_transformers import SentenceTransformer
|
| from transformers import AutoTokenizer
|
| from transformers import AutoModelForSeq2SeqLM
|
|
|
|
|
| import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
# Central runtime settings: hub ids of the retriever and generator models,
# directory holding the persisted index/docs, retrieval depth, generation
# length cap, and the torch device (GPU when available, else CPU).
CONFIG = dict(
    retriever_model_path="swathibp/BGE-base_finetuned",
    generator_model_path="swathibp/Flan_T5_merged",
    save_dir=".",
    top_k=3,
    max_new_tokens=250,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
|
|
|
# Make sure the artifact directory exists (no-op if it is already there),
# then report which device the models will be placed on.
os.makedirs(CONFIG["save_dir"], exist_ok=True)

print("DEVICE:", CONFIG["device"])
|
|
|
|
|
|
|
|
|
|
|
|
|
# Paths of the persisted retrieval artifacts inside CONFIG["save_dir"].
# os.path.join uses the platform separator and does not produce a doubled
# separator when save_dir already ends with one.
INDEX_FILE = os.path.join(CONFIG["save_dir"], "index.faiss")  # FAISS vector index
DOC_FILE = os.path.join(CONFIG["save_dir"], "docs.pkl")  # pickled document list
|
|
|
|
|
print("Loading Retriever...")

# Sentence-embedding model that retrieve() uses to embed incoming questions.
retriever = SentenceTransformer(CONFIG["retriever_model_path"])
|
|
|
# Load the prebuilt FAISS index and the document list it was built from.
# Nothing in this script builds the index, so fail fast at startup with a
# clear message. Otherwise `index` and `documents` would stay undefined and
# every query would later fail with a NameError inside retrieve().
if not os.path.exists(INDEX_FILE):
    raise FileNotFoundError(
        f"FAISS index not found at {INDEX_FILE!r}. Build the index and "
        f"{DOC_FILE!r} before launching the app."
    )

print("Loading Stored FAISS Index")

index = faiss.read_index(INDEX_FILE)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a docs.pkl produced by a trusted build step.
with open(DOC_FILE, "rb") as f:
    documents = pickle.load(f)

# retrieve() uses FAISS row ids directly as positions in `documents`. If the
# index holds more vectors than there are documents, the two files are out
# of sync and some hits would raise IndexError at query time.
if index.ntotal > len(documents):
    raise ValueError(
        f"{INDEX_FILE!r} holds {index.ntotal} vectors but {DOC_FILE!r} "
        f"holds only {len(documents)} documents; rebuild them together."
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
print("Loading FLAN Generator...")

# Tokenizer and seq2seq model are both loaded from CONFIG["generator_model_path"].
tokenizer = AutoTokenizer.from_pretrained(CONFIG["generator_model_path"])
generator = AutoModelForSeq2SeqLM.from_pretrained(CONFIG["generator_model_path"])
generator = generator.to(CONFIG["device"])  # Module.to returns the module itself
generator.eval()  # inference only: disable dropout and other training-mode behaviour

print("Generator Loaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve(query, top_k=None):
    """Return the stored documents most similar to *query*, best match first.

    The query is embedded with the module-level ``retriever``, L2-normalised
    in place, and searched against the module-level FAISS ``index``. The
    resulting row ids are mapped back to entries of ``documents``.

    Args:
        query: Question text to embed.
        top_k: Number of neighbours to request from FAISS. Defaults to
            ``CONFIG["top_k"]``.

    Returns:
        A list of at most ``top_k`` documents. It is shorter than ``top_k``
        when the index holds fewer vectors than were requested.
    """
    k = CONFIG["top_k"] if top_k is None else top_k

    emb = retriever.encode([query], convert_to_numpy=True)
    # NOTE(review): normalising the query presumably matches an index built
    # from normalised vectors with inner-product search (cosine) — confirm.
    faiss.normalize_L2(emb)

    _scores, indices = index.search(emb, k)

    # FAISS pads unfilled result slots with id -1. Without this filter,
    # documents[-1] would silently return the last document as a "hit".
    return [documents[idx] for idx in indices[0] if idx >= 0]
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_prompt(instruction, context, query):
    """Render the generator prompt: instruction, context, question, answer cue."""
    return (
        f"\n\n{instruction}\n\n"
        f"Context:\n\n{context}\n\n"
        f"Question:\n\n{query}\n\n"
        "Answer:\n\n"
    )


def _fit_context(instruction, context, query):
    """Trim *context* so the assembled prompt fits the tokenizer's input limit.

    ``tokenizer(..., truncation=True)`` trims from the end of the text by
    default. The question and the "Answer:" cue come after the context, so
    an over-long context would silently remove them. Trimming only the
    context keeps the question visible to the model.
    """
    limit = tokenizer.model_max_length
    # Tokenizers without a configured limit report a huge sentinel value.
    # In that case truncation=True cuts nothing either, so leave the context as is.
    if limit is None or limit > 1_000_000:
        return context

    # Tokens taken up by everything except the context (incl. special tokens).
    overhead = len(tokenizer(_build_prompt(instruction, "", query)).input_ids)
    budget = max(limit - overhead, 0)

    context_ids = tokenizer(context, add_special_tokens=False).input_ids
    if len(context_ids) <= budget:
        return context
    return tokenizer.decode(context_ids[:budget], skip_special_tokens=True)


def generate(query):
    """Answer *query* with retrieval-augmented generation.

    Retrieves documents with retrieve(), builds an instruction-style prompt,
    and decodes the generator's output.

    Args:
        query: The user's question text.

    Returns:
        ``(answer, context)``. ``answer`` is the decoded model output.
        ``context`` is the retrieved documents joined with newlines, shown
        untrimmed in the UI. For an empty or blank query, a prompt message
        and an empty context are returned without running the models.
    """
    if not query or not query.strip():
        return "Please enter a question.", ""

    docs = retrieve(query)

    instruction = (
        "Answer ONLY using the information provided in the context. "
        "If the answer is not available, reply exactly: "
        "'Not found in the provided documents.'"
    )

    # assumes retrieve() yields strings; str.join raises TypeError otherwise.
    context = "\n".join(docs)

    prompt = _build_prompt(
        instruction, _fit_context(instruction, context, query), query
    )

    # truncation=True stays as a safety net: re-tokenising the trimmed context
    # can differ from the budget by a token or two.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(
        CONFIG["device"]
    )

    with torch.no_grad():
        outputs = generator.generate(
            **inputs,
            max_new_tokens=CONFIG["max_new_tokens"],
            do_sample=False,
            # Only affects beam search (num_beams > 1, e.g. via the model's
            # generation config); with greedy decoding it is unused.
            early_stopping=True,
        )

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return answer, context
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio front end: a question box and a button wired to generate(), whose
# (answer, context) result fills the two output boxes below.
with gr.Blocks() as demo:
    gr.Markdown("# MAHE QA System")

    q = gr.Textbox(
        label="Question",
        placeholder="Enter your MAHE question here...",
        lines=3,
        max_lines=5,
    )
    ask = gr.Button("Generate Answer")
    ans = gr.Textbox(label="Answer", lines=15, max_lines=30)
    ctx = gr.Textbox(label="Retrieved Context", lines=20, max_lines=40)

    ask.click(fn=generate, inputs=q, outputs=[ans, ctx])

# NOTE(review): share=True asks Gradio for a public share link, which exposes
# the app beyond localhost. Confirm this is intended. debug=True keeps the
# launch call blocking so errors are printed.
demo.launch(share=True, debug=True)