import gradio as gr
from huggingface_hub import InferenceClient
#from sentence_transformers import SentenceTransformer
#import faiss
import numpy as np
# =========================
# Simple HF Embedding Retrieval (No Local Models)
# =========================
# Remote embedding endpoint — avoids downloading a local sentence-transformers model.
embedding_client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2")
def embed_texts(texts):
    """Embed one string or a list of strings via the HF Inference API.

    Returns a numpy array with one embedding row per input text.
    """
    batch = [texts] if isinstance(texts, str) else texts
    return np.array(embedding_client.feature_extraction(batch))
# =========================
# Load and Prepare Gita Text
# =========================
# Read the entire corpus once at import time; chunking/embedding happens below.
with open("gita.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()
def chunk_text(text, chunk_size=500, overlap=50):
    """Split *text* into overlapping fixed-size character chunks.

    Args:
        text: Source string to split.
        chunk_size: Maximum characters per chunk.
        overlap: Characters shared between consecutive chunks.

    Returns:
        List of chunk strings covering the whole input ([] for empty input).

    Raises:
        ValueError: If chunk_size is not positive or overlap >= chunk_size
            (the original loop would never advance and spin forever).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    step = chunk_size - overlap
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]
# Chunk the corpus and embed every chunk once at startup (module-level cache
# reused by retrieve() on every query).
documents = chunk_text(raw_text)
doc_embeddings = embed_texts(documents)
# Embedding model (small + free)
#embedder = SentenceTransformer("all-MiniLM-L6-v2")
#doc_embeddings = #embedder.encode(documents)
#dimension = doc_embeddings.shape[1]
#doc_embeddings = embedder.encode(documents)
def retrieve(query, top_k=4):
    """Return the top_k most similar corpus chunks, joined by blank lines.

    Similarity is the raw dot product between the query embedding and the
    precomputed chunk embeddings.
    """
    query_vec = embed_texts(query)[0]
    similarity = doc_embeddings @ query_vec
    best = np.argsort(similarity)[::-1][:top_k]
    return "\n\n".join(documents[idx] for idx in best)
# index = faiss.IndexFlatL2(dimension)
# index.add(np.array(doc_embeddings))
# def retrieve(query, top_k=4):
# query_embedding = embedder.encode([query])
# distances, indices = index.search(np.array(query_embedding), top_k)
# results = [documents[i] for i in indices[0]]
# return "\n\n".join(results)
# =========================
# RAG Chat Function
# =========================
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,
):
    """Stream a chat completion grounded in retrieved Bhagavad Gita context.

    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

    Args:
        message: Latest user message.
        history: Prior turns as {"role", "content"} dicts (ChatInterface "messages" format).
        system_message: Base system prompt from the UI textbox.
        max_tokens: Generation cap passed to the inference endpoint.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        hf_token: OAuth token injected by the gradio LoginButton.

    Yields:
        The accumulated response text after each streamed delta.
    """
    client = InferenceClient(token=hf_token.token)

    # Ground the answer in the most similar Gita chunks for this query.
    context = retrieve(message)
    augmented_system_message = (
        system_message
        + "\n\nYou are RAGVeda, an expert in Indian philosophy."
        + "\nAnswer ONLY using the Bhagavad Gita context below."
        + "\nIf answer not found, say you do not know."
        + "\n\nContext:\n"
        + context
    )
    messages = [{"role": "system", "content": augmented_system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    response = ""
    # Loop variable renamed from `message`: the original shadowed the user-message
    # parameter, which invites bugs if the parameter is referenced after the loop.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        if choices and choices[0].delta.content:
            response += choices[0].delta.content
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
respond,
type="messages",
additional_inputs=[
gr.Textbox(
value="You are RAGVeda, a calm and wise assistant rooted in the Bhagavad Gita.",
label="System message",
),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
)
# Wrap the chat in a Blocks layout with a sidebar login button — the login
# supplies the gr.OAuthToken that respond() requires.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()
if __name__ == "__main__":
    demo.launch()