# app.py
"""HyDE RAG Tutor.

Gradio app that answers textbook questions with retrieval-augmented
generation: a hypothetical answer (HyDE) is generated by the LLM, embedded,
used to retrieve the nearest chunks from a prebuilt FAISS index, and the
LLM then answers strictly from that retrieved context.
"""

import os
import pickle

import faiss
import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login


# -----------------------------
# Load Prebuilt Index
# -----------------------------
def load_prebuilt_index(index_dir="prebuilt_index"):
    """Load the FAISS index, chunk texts, and the matching embedding model.

    Args:
        index_dir: Directory containing ``faiss_index.bin`` and ``metadata.pkl``.

    Returns:
        Tuple of ``(embed_model, index, chunks)`` where *embed_model* is the
        SentenceTransformer named in the metadata, *index* is the FAISS index,
        and *chunks* is the list of text chunks aligned with the index rows.
    """
    index = faiss.read_index(os.path.join(index_dir, "faiss_index.bin"))
    # NOTE(security): pickle deserialization is only acceptable because
    # metadata.pkl is produced by our own index-build step — never point
    # this at an untrusted file.
    with open(os.path.join(index_dir, "metadata.pkl"), "rb") as f:
        metadata = pickle.load(f)
    # Re-instantiate the exact embedder used at build time so query vectors
    # live in the same space as the indexed vectors.
    embed_model = SentenceTransformer(metadata["model_name"])
    return embed_model, index, metadata["chunks"]


embed_model, faiss_index, chunks = load_prebuilt_index()


# -----------------------------
# Load LLaMA
# -----------------------------
def load_llm():
    """Authenticate with the Hugging Face Hub and load Llama 3.2 3B Instruct.

    Returns:
        Tuple of ``(tokenizer, model)``; the model is placed automatically
        across available devices in float16.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
    """
    model_id = "meta-llama/Llama-3.2-3b-instruct"
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not set. Run: export HF_TOKEN=your_token")
    login(hf_token)
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    llm = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        token=hf_token,
    )
    return tokenizer, llm


tokenizer, llm = load_llm()


# -----------------------------
# Answer Query
# -----------------------------
def _generate(user_content, **gen_kwargs):
    """Run one chat turn and return ONLY the newly generated text.

    Two fixes over the previous inline version:
    - Llama 3.2 Instruct does not use the Llama-2 ``[INST]`` format; the
      tokenizer's own chat template produces the correct prompt.
    - ``generate()`` returns prompt + completion tokens, so we slice off the
      prompt before decoding instead of decoding (and later re-parsing) the
      echoed prompt.
    """
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(llm.device)
    output_ids = llm.generate(input_ids, **gen_kwargs)
    new_tokens = output_ids[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def answer_query(question):
    """Answer *question* from the indexed textbook via HyDE retrieval.

    Pipeline:
        1. Generate a hypothetical answer (HyDE) with the LLM.
        2. Embed it and retrieve the top-5 nearest chunks from FAISS.
        3. Ask the LLM to answer strictly from the retrieved context.

    Returns:
        The answer string, or a "⚠️ Error: ..." message on failure (this is
        the Gradio callback boundary, so we surface errors rather than crash).
    """
    try:
        # ---- HyDE hypothetical answer ----
        hypo_answer = _generate(
            f"Write a detailed hypothetical answer:\n{question}",
            max_new_tokens=200,
        )

        # ---- Embed hypo answer ----
        query_vec = (
            embed_model.encode([hypo_answer])[0].astype("float32").reshape(1, -1)
        )

        # ---- Retrieve chunks ----
        _, indices = faiss_index.search(query_vec, k=5)
        # FAISS pads results with -1 when fewer than k neighbours exist;
        # chunks[-1] would silently return the LAST chunk, so filter them.
        relevant_chunks = [chunks[i] for i in indices[0] if i >= 0]
        context = "\n".join(relevant_chunks)

        # ---- Final Answer ----
        final_prompt = (
            "You are a helpful tutor.\n"
            "Based only on the context below, answer the question. "
            'If not in context, say "I could not find this in the text."\n\n'
            f"Context:\n{context}\n\n"
            f"Question: {question}\n\nAnswer:"
        )
        answer = _generate(
            final_prompt,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
        # The model occasionally restates "Answer:"; keep only what follows.
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()
        return answer
    except Exception as e:
        return f"⚠️ Error: {e}"


# -----------------------------
# Gradio UI
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """ # 📘 HyDE RAG Tutor Ask questions from your textbook with **retrieval-augmented generation**. """
    )
    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(
                placeholder="Ask me anything from the textbook...",
                label="Your Question",
                lines=2,
            )
            btn = gr.Button("✨ Get Answer", variant="primary")
        with gr.Column(scale=2):
            answer = gr.Textbox(label="Answer", interactive=False, lines=10)
    btn.click(fn=answer_query, inputs=question, outputs=answer)

# Guard the launch so the module can be imported (e.g. for tests) without
# starting the server; running `python app.py` behaves exactly as before.
if __name__ == "__main__":
    demo.launch()