File size: 3,619 Bytes
eb2d89a
 
8a5d492
eb2d89a
 
3dda9b8
a0d291e
 
 
 
eb2d89a
a0d291e
eb2d89a
 
 
 
 
 
a0d291e
eb2d89a
a0d291e
 
eb2d89a
a0d291e
 
 
 
 
eb2d89a
a0d291e
 
 
 
 
 
 
 
 
 
eb2d89a
 
a0d291e
eb2d89a
a0d291e
eb2d89a
a0d291e
eb2d89a
 
a0d291e
 
 
 
eb2d89a
 
a0d291e
eb2d89a
 
 
a0d291e
 
eb2d89a
a0d291e
 
eb2d89a
a0d291e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb2d89a
 
 
 
 
 
 
533e155
eb2d89a
 
 
 
 
 
 
 
 
 
 
 
 
3dda9b8
 
eb2d89a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# app.py
import os, pickle
import faiss
import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# -----------------------------
# Load Prebuilt Index
# -----------------------------
def load_prebuilt_index(index_dir="prebuilt_index"):
    """Load a prebuilt FAISS index plus its chunk metadata from *index_dir*.

    Expects two files inside ``index_dir``:
      - ``faiss_index.bin``: a serialized FAISS index
      - ``metadata.pkl``: a pickled dict with at least the keys
        ``"model_name"`` (SentenceTransformer id used to build the index)
        and ``"chunks"`` (the text chunks, indexed in the same order as
        the vectors in the FAISS index) — TODO confirm against the
        index-building script, which is not visible here.

    Returns:
        (embed_model, index, chunks) — the embedding model, the FAISS
        index, and the list of text chunks.

    Raises:
        FileNotFoundError: if either file is missing.
    """
    index = faiss.read_index(os.path.join(index_dir, "faiss_index.bin"))
    # SECURITY NOTE(review): pickle.load executes arbitrary code from the
    # file — only ship a metadata.pkl you built yourself; never load one
    # from an untrusted source.
    with open(os.path.join(index_dir, "metadata.pkl"), "rb") as f:
        metadata = pickle.load(f)
    embed_model = SentenceTransformer(metadata["model_name"])
    return embed_model, index, metadata["chunks"]

# Loaded once at import time; shared by every request handled below.
embed_model, faiss_index, chunks = load_prebuilt_index()

# -----------------------------
# Load LLaMA
# -----------------------------
def load_llm():
    """Authenticate with the Hugging Face Hub and load the Llama chat model.

    Reads the access token from the ``HF_TOKEN`` environment variable
    (the meta-llama repos are gated and require an accepted license).

    Returns:
        (tokenizer, model) ready for generation.

    Raises:
        ValueError: if ``HF_TOKEN`` is not set in the environment.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        raise ValueError("HF_TOKEN not set. Run: export HF_TOKEN=your_token")
    login(token)

    repo_id = "meta-llama/Llama-3.2-3b-instruct"
    tok = AutoTokenizer.from_pretrained(repo_id, token=token)
    # fp16 + device_map="auto" lets accelerate place the weights on
    # whatever GPU/CPU mix is available.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        token=token,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    return tok, model

# Loaded once at import time; shared by every request handled below.
tokenizer, llm = load_llm()

# -----------------------------
# Answer Query
# -----------------------------
def answer_query(question):
    """Answer *question* with HyDE-style RAG.

    Pipeline: (1) generate a hypothetical answer with the LLM, (2) embed
    that hypothetical answer and use it as the retrieval query against
    the FAISS index, (3) generate the final answer grounded in the
    retrieved chunks.

    Returns:
        The model's answer as a string, or an "⚠️ Error: ..." string on
        any failure (this function is a Gradio callback, so errors are
        surfaced in the UI rather than raised).
    """
    try:
        # ---- HyDE hypothetical answer ----
        hyde_prompt = f"[INST] Write a detailed hypothetical answer:\n{question} [/INST]"
        inputs = tokenizer(hyde_prompt, return_tensors="pt").to(llm.device)
        hyde_out = llm.generate(**inputs, max_new_tokens=200)
        # BUGFIX: generate() returns prompt + continuation. Decoding the
        # whole sequence embedded the instruction text along with the
        # hypothetical answer, polluting the retrieval query — decode only
        # the newly generated tokens.
        prompt_len = inputs["input_ids"].shape[1]
        hypo_answer = tokenizer.decode(hyde_out[0][prompt_len:], skip_special_tokens=True)

        # ---- Embed hypo answer ----
        query_vec = embed_model.encode([hypo_answer])[0].astype("float32").reshape(1, -1)

        # ---- Retrieve chunks ----
        D, I = faiss_index.search(query_vec, k=5)
        # FAISS pads with id -1 when the index holds fewer than k vectors;
        # filter those out instead of letting chunks[-1] silently alias
        # the last chunk.
        relevant_chunks = [chunks[i] for i in I[0] if i != -1]
        context = "\n".join(relevant_chunks)

        # ---- Final Answer ----
        final_prompt = f"""
        [INST] You are a helpful tutor. Based only on the context below, answer the question.
        If not in context, say "I could not find this in the text."
        Context:
        {context}
        Question: {question}
        Answer: [/INST]
        """
        inputs = tokenizer(final_prompt, return_tensors="pt", truncation=True).to(llm.device)
        outputs = llm.generate(**inputs, max_new_tokens=300, temperature=0.7, top_p=0.9, do_sample=True)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # The decoded text echoes the prompt; keep only what follows the
        # final "Answer:" marker.
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()
        return answer
    except Exception as e:
        # UI boundary: report the failure in the answer box, don't crash.
        return f"⚠️ Error: {e}"

# -----------------------------
# Gradio UI
# -----------------------------
# Two-column layout: question entry on the left, answer display on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 📘 HyDE RAG Tutor  
        Ask questions from your textbook with **retrieval-augmented generation**.  
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Input pane.
            question_box = gr.Textbox(
                label="Your Question",
                placeholder="Ask me anything from the textbook...",
                lines=2,
            )
            ask_btn = gr.Button("✨ Get Answer", variant="primary")

        with gr.Column(scale=2):
            # Output pane — read-only so users can't edit the answer.
            answer_box = gr.Textbox(label="Answer", lines=10, interactive=False)

    # Wire the button to the RAG pipeline.
    ask_btn.click(fn=answer_query, inputs=question_box, outputs=answer_box)

demo.launch()