# HyDE_RAG / app.py
# Author: Rishitha3 — commit eb2d89a (verified): "Update app.py"
# app.py
import os, pickle
import faiss
import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
# -----------------------------
# Load Prebuilt Index
# -----------------------------
def load_prebuilt_index(index_dir="prebuilt_index"):
    """Load the prebuilt FAISS index, chunk texts, and embedding model.

    Args:
        index_dir: Directory containing ``faiss_index.bin`` and ``metadata.pkl``.

    Returns:
        Tuple of (SentenceTransformer embed model, FAISS index, list of chunks).
    """
    index = faiss.read_index(os.path.join(index_dir, "faiss_index.bin"))
    # NOTE(security): pickle.load executes arbitrary code from the file;
    # metadata.pkl must come from a trusted build step, never user upload.
    with open(os.path.join(index_dir, "metadata.pkl"), "rb") as f:
        metadata = pickle.load(f)
    # Re-create the exact sentence-transformer that built the index so that
    # query vectors live in the same embedding space as the indexed chunks.
    embed_model = SentenceTransformer(metadata["model_name"])
    return embed_model, index, metadata["chunks"]


# Loaded once at import time; reused by every query.
embed_model, faiss_index, chunks = load_prebuilt_index()
# -----------------------------
# Load LLaMA
# -----------------------------
def load_llm():
    """Authenticate with the Hugging Face Hub and load the chat LLM.

    Returns:
        Tuple of (tokenizer, causal-LM model) for Llama 3.2 3B Instruct.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is unset/empty.
    """
    if not (hf_token := os.getenv("HF_TOKEN")):
        raise ValueError("HF_TOKEN not set. Run: export HF_TOKEN=your_token")
    login(hf_token)

    model_id = "meta-llama/Llama-3.2-3b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    llm = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",          # let accelerate place layers on GPU/CPU
        torch_dtype=torch.float16,  # halve memory footprint vs fp32
        token=hf_token,
    )
    return tokenizer, llm


# Loaded once at import time; reused by every query.
tokenizer, llm = load_llm()
# -----------------------------
# Answer Query
# -----------------------------
def answer_query(question):
    """Answer a question via HyDE retrieval over the prebuilt FAISS index.

    Pipeline: (1) generate a hypothetical answer (HyDE), (2) embed it and
    retrieve the top-5 most similar chunks, (3) ask the LLM to answer the
    question grounded only in that retrieved context.

    Args:
        question: The user's question as a plain string.

    Returns:
        The model's answer string, or an "⚠️ Error: ..." message on failure.
    """
    try:
        # ---- HyDE: generate a hypothetical answer to embed ----
        hyde_prompt = f"[INST] Write a detailed hypothetical answer:\n{question} [/INST]"
        inputs = tokenizer(hyde_prompt, return_tensors="pt").to(llm.device)
        hyde_out = llm.generate(**inputs, max_new_tokens=200)
        # FIX: decode only the newly generated tokens. Decoding the whole
        # sequence prepends the instruction prompt to the text we embed,
        # which pollutes retrieval.
        prompt_len = inputs["input_ids"].shape[1]
        hypo_answer = tokenizer.decode(hyde_out[0][prompt_len:], skip_special_tokens=True)

        # ---- Embed the hypothetical answer as the query vector ----
        query_vec = embed_model.encode([hypo_answer])[0].astype("float32").reshape(1, -1)

        # ---- Retrieve top-k chunks (FAISS pads I with -1 when it has
        # fewer than k vectors; skip those sentinels) ----
        D, I = faiss_index.search(query_vec, 5)
        relevant_chunks = [chunks[i] for i in I[0] if i >= 0]
        context = "\n".join(relevant_chunks)

        # ---- Final grounded answer ----
        final_prompt = f"""
[INST] You are a helpful tutor. Based only on the context below, answer the question.
If not in context, say "I could not find this in the text."
Context:
{context}
Question: {question}
Answer: [/INST]
"""
        inputs = tokenizer(final_prompt, return_tensors="pt", truncation=True).to(llm.device)
        outputs = llm.generate(**inputs, max_new_tokens=300, temperature=0.7, top_p=0.9, do_sample=True)
        # FIX: slice off the prompt tokens here too, instead of relying
        # solely on the fragile "Answer:" string split below.
        prompt_len = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # Fallback: strip an echoed "Answer:" prefix if the model emits one.
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()
        return answer.strip()
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing Gradio.
        return f"⚠️ Error: {e}"
# -----------------------------
# Gradio UI
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # App header banner.
    gr.Markdown(
        """
# 📘 HyDE RAG Tutor
Ask questions from your textbook with **retrieval-augmented generation**.
"""
    )
    with gr.Row():
        # Left column: question input plus submit button.
        with gr.Column(scale=1):
            question_box = gr.Textbox(
                label="Your Question",
                placeholder="Ask me anything from the textbook...",
                lines=2,
            )
            ask_btn = gr.Button("✨ Get Answer", variant="primary")
        # Right column: read-only answer display.
        with gr.Column(scale=2):
            answer_box = gr.Textbox(label="Answer", interactive=False, lines=10)
    # Wire the button to the RAG pipeline.
    ask_btn.click(fn=answer_query, inputs=question_box, outputs=answer_box)

demo.launch()