|
|
|
|
|
import os, pickle |
|
|
import faiss |
|
|
import torch |
|
|
import gradio as gr |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from huggingface_hub import login |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_prebuilt_index(index_dir="prebuilt_index"):
    """Load a pre-serialized FAISS index plus its companion metadata.

    Args:
        index_dir: Directory holding ``faiss_index.bin`` and ``metadata.pkl``.

    Returns:
        Tuple of (SentenceTransformer embedding model, FAISS index,
        list of text chunks aligned with the index vectors).
    """
    index_path = os.path.join(index_dir, "faiss_index.bin")
    meta_path = os.path.join(index_dir, "metadata.pkl")

    vector_index = faiss.read_index(index_path)

    # NOTE(review): pickle is only safe on files we produced ourselves —
    # never point index_dir at untrusted data.
    with open(meta_path, "rb") as fh:
        meta = pickle.load(fh)

    # Re-instantiate the exact embedding model the index was built with.
    model = SentenceTransformer(meta["model_name"])
    return model, vector_index, meta["chunks"]
|
|
|
|
|
# Load the retrieval stack once at import time so every request reuses it.
embed_model, faiss_index, chunks = load_prebuilt_index()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_llm(model_id="meta-llama/Llama-3.2-3b-instruct"):
    """Authenticate with Hugging Face and load a causal LM with its tokenizer.

    Args:
        model_id: Hub repo id of the model to load. Defaults to the same
            Llama 3.2 3B Instruct id the original hard-coded, so existing
            callers are unaffected.

    Returns:
        Tuple of (tokenizer, model); the model is sharded automatically via
        ``device_map="auto"`` and loaded in float16.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not set. Run: export HF_TOKEN=your_token")

    # login() caches the token process-wide; strictly redundant given the
    # explicit token= arguments below, but kept so any other hub calls in
    # this process are also authenticated.
    login(hf_token)

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    llm = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,  # half precision keeps a 3B model GPU-friendly
        token=hf_token,
    )
    return tokenizer, llm
|
|
|
|
|
# Load tokenizer + LLM once at startup (requires HF_TOKEN in the environment).
tokenizer, llm = load_llm()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def answer_query(question):
    """Answer a question with HyDE-style retrieval-augmented generation.

    Pipeline: (1) generate a hypothetical answer to the question (HyDE),
    (2) embed that hypothetical answer and retrieve the 5 nearest chunks
    from the FAISS index, (3) ask the LLM to answer using only the
    retrieved context.

    Args:
        question: The user's question as plain text.

    Returns:
        The model's answer string, or a "⚠️ Error: ..." message if any step
        fails (kept best-effort so the Gradio UI never crashes).
    """
    try:
        # --- Step 1: HyDE — generate a hypothetical answer to embed. ---
        hyde_prompt = f"[INST] Write a detailed hypothetical answer:\n{question} [/INST]"
        inputs = tokenizer(hyde_prompt, return_tensors="pt").to(llm.device)
        hyde_out = llm.generate(**inputs, max_new_tokens=200)
        # BUG FIX: decode only the newly generated tokens. Previously the
        # full sequence (prompt included) was decoded, so the instruction
        # text leaked into the embedding query and degraded retrieval.
        prompt_len = inputs["input_ids"].shape[1]
        hypo_answer = tokenizer.decode(hyde_out[0][prompt_len:], skip_special_tokens=True)

        # --- Step 2: retrieve the chunks nearest the hypothetical answer. ---
        query_vec = embed_model.encode([hypo_answer])[0].astype("float32").reshape(1, -1)
        D, I = faiss_index.search(query_vec, k=5)
        relevant_chunks = [chunks[i] for i in I[0]]
        context = "\n".join(relevant_chunks)

        # --- Step 3: answer grounded in the retrieved context only. ---
        final_prompt = f"""
[INST] You are a helpful tutor. Based only on the context below, answer the question.
If not in context, say "I could not find this in the text."
Context:
{context}
Question: {question}
Answer: [/INST]
"""
        inputs = tokenizer(final_prompt, return_tensors="pt", truncation=True).to(llm.device)
        outputs = llm.generate(**inputs, max_new_tokens=300, temperature=0.7, top_p=0.9, do_sample=True)
        # Same fix as above: skip the echoed prompt tokens before decoding.
        prompt_len = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        # Defensive fallback in case the model still echoes the scaffold.
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()
        return answer
    except Exception as e:
        # Best-effort boundary: surface the failure in the UI instead of raising.
        return f"⚠️ Error: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI: question box on the left, generated answer on the right. ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 📘 HyDE RAG Tutor
Ask questions from your textbook with **retrieval-augmented generation**.
"""
    )

    with gr.Row():
        # Input column: question textbox plus the submit button.
        with gr.Column(scale=1):
            question_box = gr.Textbox(
                label="Your Question",
                placeholder="Ask me anything from the textbook...",
                lines=2,
            )
            ask_btn = gr.Button("✨ Get Answer", variant="primary")

        # Output column: read-only answer area, given more horizontal space.
        with gr.Column(scale=2):
            answer_box = gr.Textbox(label="Answer", interactive=False, lines=10)

    # Wire the button to the RAG pipeline defined above.
    ask_btn.click(fn=answer_query, inputs=question_box, outputs=answer_box)

demo.launch()
|
|
|
|
|
|