# app.py
import os, pickle
import faiss
import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
# -----------------------------
# Load Prebuilt Index
# -----------------------------
def load_prebuilt_index(index_dir="prebuilt_index"):
    """Load the prebuilt FAISS index, its chunk list, and the matching embedder.

    Reads ``faiss_index.bin`` and ``metadata.pkl`` from *index_dir*.  The
    pickle is expected to hold ``model_name`` (a sentence-transformers model
    id) and ``chunks`` (the text passages the index was built from).

    NOTE(review): ``pickle.load`` can execute arbitrary code — only load
    index directories you built yourself.
    """
    index_path = os.path.join(index_dir, "faiss_index.bin")
    meta_path = os.path.join(index_dir, "metadata.pkl")

    index = faiss.read_index(index_path)
    with open(meta_path, "rb") as fh:
        meta = pickle.load(fh)

    # Re-instantiate the same embedding model the index was built with.
    embedder = SentenceTransformer(meta["model_name"])
    return embedder, index, meta["chunks"]


embed_model, faiss_index, chunks = load_prebuilt_index()
# -----------------------------
# Load LLaMA
# -----------------------------
def load_llm():
    """Authenticate with the Hugging Face Hub and load Llama 3.2 3B Instruct.

    The model is gated, so the ``HF_TOKEN`` environment variable must be set.
    Weights are loaded in float16 and sharded automatically across whatever
    devices ``device_map="auto"`` finds.

    Returns:
        (tokenizer, model) pair.

    Raises:
        ValueError: if ``HF_TOKEN`` is not set.
    """
    model_id = "meta-llama/Llama-3.2-3b-instruct"

    token = os.getenv("HF_TOKEN")
    if not token:
        raise ValueError("HF_TOKEN not set. Run: export HF_TOKEN=your_token")
    login(token)

    tok = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        token=token,
    )
    return tok, model


tokenizer, llm = load_llm()
# -----------------------------
# Answer Query
# -----------------------------
def answer_query(question):
    """Answer *question* with HyDE-style retrieval-augmented generation.

    Pipeline:
      1. Generate a hypothetical answer with the LLM (HyDE).
      2. Embed that answer and retrieve the top-5 nearest chunks from FAISS.
      3. Generate a grounded final answer from the retrieved context.

    Returns the answer string, or a "⚠️ Error: ..." string on failure — the
    broad except is deliberate: this is the Gradio UI boundary, and errors
    are surfaced in the output textbox rather than crashing the app.
    """
    try:
        # ---- HyDE hypothetical answer ----
        hyde_prompt = f"[INST] Write a detailed hypothetical answer:\n{question} [/INST]"
        inputs = tokenizer(hyde_prompt, return_tensors="pt").to(llm.device)
        with torch.inference_mode():  # no autograd graph needed for generation
            hyde_out = llm.generate(**inputs, max_new_tokens=200)
        # Bug fix: decode ONLY the newly generated tokens. Decoding the full
        # sequence included the instruction prompt in the text that gets
        # embedded and searched below.
        prompt_len = inputs["input_ids"].shape[1]
        hypo_answer = tokenizer.decode(hyde_out[0][prompt_len:], skip_special_tokens=True)

        # ---- Embed hypo answer ----
        query_vec = embed_model.encode([hypo_answer])[0].astype("float32").reshape(1, -1)

        # ---- Retrieve chunks ----
        D, I = faiss_index.search(query_vec, k=5)
        # FAISS pads missing hits with -1; without the filter that index
        # would silently wrap to chunks[-1].
        relevant_chunks = [chunks[i] for i in I[0] if i != -1]
        context = "\n".join(relevant_chunks)

        # ---- Final Answer ----
        final_prompt = f"""
[INST] You are a helpful tutor. Based only on the context below, answer the question.
If not in context, say "I could not find this in the text."
Context:
{context}
Question: {question}
Answer: [/INST]
"""
        inputs = tokenizer(final_prompt, return_tensors="pt", truncation=True).to(llm.device)
        with torch.inference_mode():
            outputs = llm.generate(**inputs, max_new_tokens=300, temperature=0.7, top_p=0.9, do_sample=True)
        # Slice off the prompt tokens instead of splitting on "Answer:" —
        # that marker can also appear inside the generated text.
        answer = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        ).strip()
        return answer
    except Exception as e:
        return f"⚠️ Error: {e}"
# -----------------------------
# Gradio UI
# -----------------------------
# Assemble the Gradio interface: question input on the left, answer on the
# right, wired to answer_query.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 📘 HyDE RAG Tutor
Ask questions from your textbook with **retrieval-augmented generation**.
"""
    )
    with gr.Row():
        # Input column (narrow).
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Your Question",
                placeholder="Ask me anything from the textbook...",
                lines=2,
            )
            ask_btn = gr.Button("✨ Get Answer", variant="primary")
        # Output column (wide, read-only).
        with gr.Column(scale=2):
            answer = gr.Textbox(label="Answer", lines=10, interactive=False)

    ask_btn.click(fn=answer_query, inputs=question, outputs=answer)

demo.launch()