# (Hugging Face Spaces page header — "Spaces / Running" — captured when this file was copied from the Space page; not code.)
"""
Gradio demo that exposes your agentic QA pipeline (uses smolagents CodeAgent + a BM25 retriever).
Intended for deployment to Hugging Face Spaces. Set HF_TOKEN in Space secrets or environment.
"""
import os
import pickle
import traceback

import gradio as gr

# Basic ML / NLP libs used by the pipeline
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever

# smolagents agent pieces
from smolagents import Tool, InferenceClientModel, CodeAgent
# -------------------------
# Document preparation
# -------------------------
def prepare_knowledge_base(cache_dir="/tmp/hf_kb_cache"):
    """
    Download and prepare the HF docs dataset, filter to transformers docs,
    chunk into smaller documents and return the processed doc list.

    Results are cached on disk (pickle) so repeated Space restarts do not
    re-download and re-split the dataset.

    Args:
        cache_dir: Directory used to store the pickled processed docs.

    Returns:
        The list of chunked ``Document`` objects, ready for BM25 indexing.
    """
    cache_path = os.path.join(cache_dir, "docs_processed.pkl")
    os.makedirs(cache_dir, exist_ok=True)

    # If cached, load and return; a corrupt or unreadable cache falls
    # through to a full rebuild instead of crashing the Space at startup.
    if os.path.exists(cache_path):
        try:
            with open(cache_path, "rb") as f:
                return pickle.load(f)
        except (OSError, pickle.UnpicklingError, EOFError, AttributeError) as exc:
            print(f"Warning: could not load KB cache ({exc!r}); rebuilding.")

    knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
    # Keep only transformers docs (same filter as the original snippet)
    knowledge_base = knowledge_base.filter(
        lambda row: row["source"].startswith("huggingface/transformers")
    )
    source_docs = [
        Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
        for doc in knowledge_base
    ]

    # Split into overlapping chunks sized for BM25 retrieval.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)

    # Best-effort cache write: a full/read-only disk must not break startup.
    try:
        with open(cache_path, "wb") as f:
            pickle.dump(docs_processed, f)
    except OSError as exc:
        print(f"Warning: could not write KB cache ({exc!r}).")
    return docs_processed
# -------------------------
# Retriever tool for agent
# -------------------------
class RetrieverTool(Tool):
    """smolagents Tool wrapping a BM25 retriever over the transformers docs."""

    name = "retriever"
    description = "Uses BM25 retrieval over transformers docs to fetch context relevant to a question."
    inputs = {
        "query": {
            "type": "string",
            "description": "A short query describing the information to retrieve (affirmative form).",
        }
    }
    output_type = "string"

    def __init__(self, docs, k=8, **kwargs):
        """Build a BM25 index over *docs*; *k* is the number of hits per query."""
        super().__init__(**kwargs)
        # Build a BM25 retriever from the processed docs
        self.retriever = BM25Retriever.from_documents(docs, k=k)

    def forward(self, query: str) -> str:
        """Retrieve the top-k documents for *query*, formatted as one string."""
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # so the validation would silently disappear in optimized runs.
        if not isinstance(query, str):
            raise TypeError("query must be a string")
        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n{doc.page_content}"
            for i, doc in enumerate(docs)
        )
# -------------------------
# Agent initialization
# -------------------------
# NOTE: On HF Spaces you can set environment variable HF_TOKEN in the UI (Settings -> Secrets).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    print("Warning: HF_TOKEN not set. If your chosen model requires authentication, set HF_TOKEN in environment/secrets.")

# Guard the heavy startup work (dataset download, model client creation) so
# the UI can still come up: run_agent() checks `agent is None` and returns a
# fallback message, but without this try/except `agent` could never be None.
try:
    # Prepare docs and build the retriever tool over them.
    DOCS = prepare_knowledge_base()
    retriever_tool = RetrieverTool(DOCS)
    model = InferenceClientModel()  # default model; pass model_id= if needed
    agent = CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=4,
        verbosity_level=1,
    )
except Exception:
    # Log the full traceback to the Space logs and fall back gracefully.
    traceback.print_exc()
    agent = None
# -------------------------
# Gradio interface
# -------------------------
def run_agent(question: str) -> str:
    """Run the agent on *question* and return the final answer (or a helpful error)."""
    if not question or question.strip() == "":
        return "Please enter a question."
    # If agent couldn't be created, return fallback info
    if agent is None:
        return "Agent not initialized in this environment. Check logs in the Space and ensure `smolagents` is installed and HF_TOKEN is configured."
    # Surface failures as text in the UI instead of an opaque Gradio error;
    # this also makes the promised "helpful error" in the docstring true.
    try:
        return agent.run(question)
    except Exception:
        return "Agent run failed:\n" + traceback.format_exc()
# Page layout + event wiring for the Space UI.
with gr.Blocks(title="Agentic RAG Demo") as demo:
    gr.Markdown(
        """
# Transformers docs QA (Agent demo)
Ask the agent a question about Hugging Face Transformers docs.
Example: *For a transformers model training, which is slower, the forward or the backward pass?*
"""
    )

    # Question input and answer output, side by side.
    with gr.Row():
        question_box = gr.Textbox(placeholder="Write your question here...", label="Question", lines=2)
        answer_box = gr.Textbox(label="Agent answer", lines=10)

    # Action buttons.
    with gr.Row():
        ask_button = gr.Button("Ask")
        clear_button = gr.Button("Clear")

    # Wire events: Ask runs the agent; Clear empties the question box.
    ask_button.click(fn=run_agent, inputs=question_box, outputs=answer_box)
    clear_button.click(lambda: "", None, question_box)

demo.launch()