File size: 4,103 Bytes
f3c304f
 
5c23b9e
c3ac79c
9e8b3eb
f3c304f
5c23b9e
 
 
 
 
 
3bb3644
 
 
 
5c23b9e
c3ac79c
 
9e8b3eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bb3644
 
 
9e8b3eb
5c23b9e
 
6718a33
 
9e8b3eb
 
3bb3644
c3ac79c
 
 
9e8b3eb
 
 
c3ac79c
 
 
 
 
 
 
 
 
 
 
9e8b3eb
 
 
f3c304f
 
 
 
 
 
 
 
 
 
 
 
 
5c23b9e
 
f3c304f
9e8b3eb
 
f3c304f
9e8b3eb
 
 
 
 
 
 
 
f3c304f
9e8b3eb
 
f3c304f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e8b3eb
 
 
 
f3c304f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
from huggingface_hub import InferenceClient
#from sentence_transformers import SentenceTransformer
#import faiss
import numpy as np

# =========================
# Simple HF Embedding Retrieval (No Local Models)
# =========================

# Remote embedding endpoint: all embeddings are computed through the Hugging
# Face Inference API, so no sentence-transformers model is downloaded locally.
embedding_client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2")

def embed_texts(texts):
    """Embed one string or a list of strings via the HF Inference API.

    A bare string is wrapped in a single-element list so the endpoint always
    receives a batch; the result is returned as a numpy array (one row per
    input text).
    """
    batch = [texts] if isinstance(texts, str) else texts
    vectors = embedding_client.feature_extraction(batch)
    return np.array(vectors)



# =========================
# Load and Prepare Gita Text
# =========================

# Read the full Bhagavad Gita corpus from disk once at import time.
# NOTE(review): assumes gita.txt sits in the working directory — the app fails
# at startup with FileNotFoundError otherwise; confirm deployment layout.
with open("gita.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

def chunk_text(text, chunk_size=500, overlap=50):
    """Split *text* into overlapping fixed-size character chunks.

    Args:
        text: Full document string (may be empty).
        chunk_size: Maximum number of characters per chunk.
        overlap: Characters shared between consecutive chunks.

    Returns:
        list[str]: Consecutive slices of *text*; ``[]`` for empty input.

    Raises:
        ValueError: If ``overlap >= chunk_size``. With the original loop this
            made the window step non-positive, so ``start`` never advanced and
            the function looped forever; fail fast instead.
    """
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than chunk_size")
    # Same start positions as the original while-loop: 0, step, 2*step, ...
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]

# Chunk the corpus and embed every chunk up-front (single remote batch call);
# both live at module level so retrieve() can reuse them on every query.
documents = chunk_text(raw_text)
doc_embeddings = embed_texts(documents)

# Embedding model (small + free)
#embedder = SentenceTransformer("all-MiniLM-L6-v2")
#doc_embeddings = #embedder.encode(documents)
#dimension = doc_embeddings.shape[1]
#doc_embeddings = embedder.encode(documents)

def retrieve(query, top_k=4):
    """Return the *top_k* Gita chunks most similar to *query*, joined by blank lines.

    Ranks by cosine similarity rather than the raw dot product the original
    used: un-normalised embeddings make the dot product favour chunks with
    larger vector norms over genuinely closer ones.

    Args:
        query: Natural-language question to search for.
        top_k: Number of chunks to include (clamped to the corpus size).

    Returns:
        str: The selected chunks, best match first, separated by "\n\n".
    """
    query_embedding = embed_texts(query)[0]
    # L2-normalise both sides; substitute 1.0 for zero norms so an all-zero
    # embedding cannot trigger a divide-by-zero.
    q_norm = np.linalg.norm(query_embedding) or 1.0
    d_norms = np.linalg.norm(doc_embeddings, axis=1)
    d_norms[d_norms == 0] = 1.0
    scores = (doc_embeddings @ query_embedding) / (d_norms * q_norm)
    # Clamp so an over-large top_k simply returns the whole corpus.
    top_k = min(top_k, len(documents))
    top_indices = np.argsort(scores)[-top_k:][::-1]
    return "\n\n".join(documents[i] for i in top_indices)


# index = faiss.IndexFlatL2(dimension)
# index.add(np.array(doc_embeddings))


# def retrieve(query, top_k=4):
#     query_embedding = embedder.encode([query])
#     distances, indices = index.search(np.array(query_embedding), top_k)
#     results = [documents[i] for i in indices[0]]
#     return "\n\n".join(results)


# =========================
# RAG Chat Function
# =========================

def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,
):
    """Stream a RAG-augmented chat completion grounded in the Bhagavad Gita.

    For more information on `huggingface_hub` Inference API support, please
    check the docs:
    https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

    Args:
        message: Latest user message.
        history: Prior turns as {"role": ..., "content": ...} dicts
            (gradio "messages" format).
        system_message: Base system prompt from the UI textbox.
        max_tokens: Generation cap from the slider.
        temperature: Sampling temperature from the slider.
        top_p: Nucleus-sampling threshold from the slider.
        hf_token: OAuth token injected by gr.LoginButton; authenticates the
            Inference API call.

    Yields:
        str: The cumulative response text after each streamed token.
    """
    client = InferenceClient(token=hf_token.token)

    # Ground the answer: fetch the Gita chunks most relevant to this query.
    context = retrieve(message)

    augmented_system_message = (
        system_message
        + "\n\nYou are RAGVeda, an expert in Indian philosophy."
        + "\nAnswer ONLY using the Bhagavad Gita context below."
        + "\nIf answer not found, say you do not know."
        + "\n\nContext:\n"
        + context
    )

    messages = [{"role": "system", "content": augmented_system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    response = ""

    # FIX: the original loop reused the name `message` for each stream chunk,
    # shadowing the user-message parameter above; renamed to `chunk`.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        # Empty `choices` or a None delta yields an empty token, not a crash.
        token = choices[0].delta.content if choices and choices[0].delta.content else ""
        response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(
            value="You are RAGVeda, a calm and wise assistant rooted in the Bhagavad Gita.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

# Wrap the chat interface in a Blocks layout so a Hugging Face login button
# can sit in a sidebar; logging in supplies the gr.OAuthToken that respond()
# requires to call the Inference API.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()


if __name__ == "__main__":
    demo.launch()