# Commit fee9a33 (verified): Upload Agentic RAG demo (app, requirements, README)
"""
Gradio demo that exposes your agentic QA pipeline (uses smolagents CodeAgent + a BM25 retriever).
Intended for deployment to Hugging Face Spaces. Set HF_TOKEN in Space secrets or environment.
"""
import os
import traceback
import gradio as gr
# Basic ML / NLP libs used by your pipeline
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
# smolagents agent pieces
from smolagents import Tool, InferenceClientModel, CodeAgent
# -------------------------
# Document preparation
# -------------------------
def prepare_knowledge_base(cache_dir="/tmp/hf_kb_cache"):
    """
    Download the HF docs dataset, keep only the ``huggingface/transformers``
    pages, split them into overlapping chunks, and return the chunk list.

    Results are pickled under *cache_dir* so repeated Space restarts skip
    the download/split work (simple file-existence cache).

    Args:
        cache_dir: Directory for the pickle cache (created if missing).

    Returns:
        list[Document]: chunked langchain documents ready for BM25 indexing.
    """
    import pickle  # single local import; `os` is already imported at module level

    cache_path = os.path.join(cache_dir, "docs_processed.pkl")
    os.makedirs(cache_dir, exist_ok=True)

    # Fast path: reuse a previously built cache if it loads cleanly.
    if os.path.exists(cache_path):
        try:
            with open(cache_path, "rb") as f:
                return pickle.load(f)
        except Exception:
            # Best-effort cache: any load failure (corrupt/stale pickle,
            # library version mismatch) falls through to a full rebuild.
            pass

    knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
    # Keep only transformers docs (same filter as the original snippet).
    knowledge_base = knowledge_base.filter(
        lambda row: row["source"].startswith("huggingface/transformers")
    )
    source_docs = [
        Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
        for doc in knowledge_base
    ]

    # ~500-char chunks with 50-char overlap keep each retrieval hit small
    # enough to fit comfortably in the agent's context.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)

    # Persist for the next run.
    with open(cache_path, "wb") as f:
        pickle.dump(docs_processed, f)
    return docs_processed
# -------------------------
# Retriever tool for agent
# -------------------------
class RetrieverTool(Tool):
    """smolagents tool exposing BM25 retrieval over the transformers docs."""

    name = "retriever"
    description = "Uses BM25 retrieval over transformers docs to fetch context relevant to a question."
    inputs = {
        "query": {
            "type": "string",
            "description": "A short query describing the information to retrieve (affirmative form).",
        }
    }
    output_type = "string"

    def __init__(self, docs, k=8, **kwargs):
        """Build a BM25 index over *docs*; each query returns the top *k* hits."""
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=k)

    def forward(self, query: str) -> str:
        """Run the retriever and return the hits formatted as one string."""
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently drop the validation.
        if not isinstance(query, str):
            raise TypeError("query must be a string")
        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n{doc.page_content}"
            for i, doc in enumerate(docs)
        )
# -------------------------
# Agent initialization
# -------------------------
# Build the knowledge base, retriever tool, and agent once at import time.
# Wrapped in try/except so a failed init (no network, missing token, bad
# model) leaves `agent = None` — run_agent() checks for this and shows a
# helpful message instead of crashing the Space.
try:
    # Prepare docs
    DOCS = prepare_knowledge_base()
    # Initialize tool instance
    retriever_tool = RetrieverTool(DOCS)

    # NOTE: On HF Spaces you can set environment variable HF_TOKEN in the UI
    # (Settings -> Secrets).
    HF_TOKEN = os.environ.get("HF_TOKEN")
    if HF_TOKEN is None:
        print("Warning: HF_TOKEN not set. If your chosen model requires authentication, set HF_TOKEN in environment/secrets.")

    model = InferenceClientModel()  # default model; you can set model_id arg if needed
    agent = CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=4,
        verbosity_level=1,
    )
except Exception:
    # Log the full traceback to the Space logs; the UI stays up.
    traceback.print_exc()
    agent = None
# -------------------------
# Gradio interface
# -------------------------
def run_agent(question: str):
    """Run the agent on *question* and return its answer (or a helpful error).

    Args:
        question: User question; blank/None input gets a prompt message.

    Returns:
        str: the agent's answer, or a human-readable error string.
    """
    if not question or question.strip() == "":
        return "Please enter a question."
    # If agent couldn't be created at import time, return fallback info.
    if agent is None:
        return "Agent not initialized in this environment. Check logs in the Space and ensure `smolagents` is installed and HF_TOKEN is configured."
    try:
        return agent.run(question)
    except Exception:
        # Surface the failure in the UI instead of a raw Gradio error page.
        return "The agent raised an error:\n" + traceback.format_exc()
# Gradio UI: one question box, one answer box, Ask/Clear buttons.
with gr.Blocks(title="Agentic RAG Demo") as demo:
    # Intro / usage text shown above the inputs.
    gr.Markdown(
        """
# Transformers docs QA (Agent demo)
Ask the agent a question about Hugging Face Transformers docs.
Example: *For a transformers model training, which is slower, the forward or the backward pass?*
"""
    )
    with gr.Row():
        # Question input and read-only answer box side by side.
        inp = gr.Textbox(placeholder="Write your question here...", label="Question", lines=2)
        out = gr.Textbox(label="Agent answer", lines=10)
    with gr.Row():
        run_btn = gr.Button("Ask")
        clear_btn = gr.Button("Clear")
    # Wire the buttons: "Ask" runs the agent; "Clear" empties the question box.
    run_btn.click(fn=run_agent, inputs=inp, outputs=out)
    clear_btn.click(lambda: "", None, inp)

# Start the Gradio server (HF Spaces runs this module top-level).
demo.launch()