"""Retrieval-augmented QA app: FAISS retrieval + FLAN-T5 generation, served with Gradio."""

import os
import pickle

import faiss
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# ==================================================
# CONFIG
# ==================================================

CONFIG = {
    "retriever_model_path": "swathibp/BGE-base_finetuned",
    "generator_model_path": "swathibp/Flan_T5_merged",
    "save_dir": ".",
    "top_k": 3,
    "max_new_tokens": 250,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

os.makedirs(CONFIG["save_dir"], exist_ok=True)

print("DEVICE:", CONFIG["device"])

# ==================================================
# LOAD FAISS INDEX + DOCUMENTS
# ==================================================
# This script only loads a prebuilt index; it does not build one.

INDEX_FILE = os.path.join(CONFIG["save_dir"], "index.faiss")
DOC_FILE = os.path.join(CONFIG["save_dir"], "docs.pkl")

if not (os.path.exists(INDEX_FILE) and os.path.exists(DOC_FILE)):
    # Without these files `index` / `documents` would be undefined and every
    # query would fail later with a NameError; fail fast at startup instead.
    raise FileNotFoundError(
        f"Expected a prebuilt FAISS index at {INDEX_FILE!r} and documents at "
        f"{DOC_FILE!r}; build them before launching this app."
    )

print("Loading Retriever...")
retriever = SentenceTransformer(CONFIG["retriever_model_path"])

print("Loading Stored FAISS Index")
index = faiss.read_index(INDEX_FILE)

# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# DOC_FILE must only ever come from a trusted source.
with open(DOC_FILE, "rb") as f:
    documents = pickle.load(f)

# ==================================================
# LOAD GENERATOR
# ==================================================

print("Loading FLAN Generator...")
tokenizer = AutoTokenizer.from_pretrained(CONFIG["generator_model_path"])
generator = AutoModelForSeq2SeqLM.from_pretrained(
    CONFIG["generator_model_path"]
).to(CONFIG["device"])
generator.eval()
print("Generator Loaded")

# ==================================================
# RETRIEVAL
# ==================================================


def retrieve(query):
    """Return the top-k documents from the FAISS index for *query*.

    The query embedding is L2-normalised before searching; presumably the
    index was built over normalised vectors with inner-product search
    (cosine similarity) -- confirm against the index-building script.

    Args:
        query: Free-text question.

    Returns:
        List of entries from ``documents`` in FAISS rank order, at most
        ``CONFIG["top_k"]`` items.
    """
    emb = retriever.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(emb)
    _, indices = index.search(emb, CONFIG["top_k"])
    # FAISS pads results with -1 when the index holds fewer than top_k
    # vectors; documents[-1] would silently return the last document.
    return [documents[idx] for idx in indices[0] if idx >= 0]


# ==================================================
# GENERATION
# ==================================================

INSTRUCTION = (
    "Answer ONLY using the information provided in the context. "
    "If the answer is not available, reply exactly: "
    "'Not found in the provided documents.'"
)


def _build_prompt(context, query):
    """Format the instruction, context and question into a single prompt."""
    return f"""
{INSTRUCTION}

Context:
{context}

Question:
{query}

Answer:
"""


def _fit_context(context, query):
    """Trim *context* so the full prompt fits the tokenizer's max length.

    Truncating the assembled prompt cuts from the end, which would drop the
    question and the "Answer:" cue; trimming only the context keeps them.
    """
    limit = tokenizer.model_max_length
    # Tokens used by everything except the context (incl. special tokens).
    overhead = len(tokenizer(_build_prompt("", query)).input_ids)
    budget = max(limit - overhead, 0)
    ctx_ids = tokenizer(context, add_special_tokens=False).input_ids
    if len(ctx_ids) <= budget:
        return context
    return tokenizer.decode(ctx_ids[:budget], skip_special_tokens=True)


def generate(query):
    """Answer *query* with the generator, grounded on retrieved documents.

    Args:
        query: Question text from the UI.

    Returns:
        ``(answer, context)``: the decoded model output and the full
        retrieved context (newline-joined), for the UI's two output boxes.
    """
    if not query or not query.strip():
        return "Please enter a question.", ""

    docs = retrieve(query)
    context = "\n".join(docs)  # assumes documents are str -- confirm
    prompt = _build_prompt(_fit_context(context, query), query)

    # truncation=True stays as a safety net; _fit_context should already
    # keep the prompt within the model's limit.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
    ).to(CONFIG["device"])

    with torch.no_grad():
        outputs = generator.generate(
            **inputs,
            max_new_tokens=CONFIG["max_new_tokens"],
            do_sample=False,
            early_stopping=True,
        )

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer, context


# ==================================================
# UI
# ==================================================

with gr.Blocks() as demo:
    gr.Markdown("# MAHE QA System")

    q = gr.Textbox(
        label="Question",
        placeholder="Enter your MAHE question here...",
        lines=3,
        max_lines=5,
    )
    ask = gr.Button("Generate Answer")
    ans = gr.Textbox(label="Answer", lines=15, max_lines=30)
    ctx = gr.Textbox(label="Retrieved Context", lines=20, max_lines=40)

    ask.click(generate, q, [ans, ctx])

if __name__ == "__main__":
    # share=True publishes a temporary public gradio.live URL.
    demo.launch(share=True, debug=True)