Spaces:
Sleeping
Sleeping
| # agent.py | |
| import os | |
| import pandas as pd | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from serpapi import GoogleSearch | |
# 1️⃣ Switch Graph → StateGraph
| from langgraph.graph import StateGraph | |
| #from langchain_core.language_models.llms import LLM | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_core.tools import tool | |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader | |
# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
# Load the corpus: one document per row of the "content" column.
df = pd.read_csv("documents.csv")
DOCS = df["content"].tolist()

# Embed every document once at import time; vectors are cast to float32
# before they are handed to the index.
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True)
EMBS = EMBS.astype("float32")

# Flat (exhaustive) L2 index sized to the embedding dimension.
_EMBEDDING_DIM = EMBS.shape[1]
INDEX = faiss.IndexFlatL2(_EMBEDDING_DIM)
INDEX.add(EMBS)
# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
# Read the system prompt once at import time; leading/trailing whitespace
# is stripped so the prompt can be embedded in a message as-is.
with open("system_prompt.txt", mode="r", encoding="utf-8") as prompt_file:
    SYSTEM_PROMPT = prompt_file.read().strip()
# ────────────────
# 4️⃣ Define your tools (with docstrings)
# ────────────────
def calculator(expr: str) -> str:
    """
    Evaluate an arithmetic expression and return its result as a string.

    SECURITY NOTE: this function used to call ``eval(expr)`` directly. The
    expression text originates from the model/user, so raw ``eval`` allowed
    arbitrary code execution. The expression is now parsed with ``ast`` and
    only a whitelist is evaluated: numeric literals, ``+ - * / // % **``,
    unary ``+``/``-``, the constants ``pi``, ``e``, ``tau`` and a small set
    of numeric functions (``abs``, ``round``, ``min``, ``max``, ``sqrt``,
    ``log``, ``log2``, ``log10``, ``exp``, ``sin``, ``cos``, ``tan``,
    ``floor``, ``ceil``). Anything else yields ``"Error"``.

    Args:
        expr: the arithmetic expression to evaluate.
    Returns:
        ``str(result)`` on success, or ``"Error"`` if the expression is
        empty, malformed, uses an unsupported construct, or fails while
        being evaluated (e.g. division by zero).
    """
    # Function-scope imports keep this block self-contained.
    import ast
    import math
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    functions = {
        "abs": abs,
        "round": round,
        "min": min,
        "max": max,
        "sqrt": math.sqrt,
        "log": math.log,
        "log2": math.log2,
        "log10": math.log10,
        "exp": math.exp,
        "sin": math.sin,
        "cos": math.cos,
        "tan": math.tan,
        "floor": math.floor,
        "ceil": math.ceil,
    }
    constants = {"pi": math.pi, "e": math.e, "tau": math.tau}
    # Basic guard against expressions such as 9**9**9 that would otherwise
    # tie up the process computing an astronomically large integer.
    max_exponent = 10_000

    def evaluate(node):
        """Recursively evaluate a whitelisted AST node; raise on anything else."""
        if isinstance(node, ast.Constant):
            if isinstance(node.value, (int, float, complex)):
                return node.value
            raise ValueError("unsupported literal")
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(right, (int, float))
                and abs(right) > max_exponent
            ):
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.Name) and node.id in constants:
            return constants[node.id]
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in functions
            and not node.keywords
        ):
            args = [evaluate(arg) for arg in node.args]
            return functions[node.func.id](*args)
        raise ValueError("unsupported expression")

    try:
        tree = ast.parse(expr.strip(), mode="eval")
        return str(evaluate(tree.body))
    except Exception:
        # Deliberately broad: the documented contract is that *any* failure
        # (syntax, unsupported node, arithmetic error, bad input type) is
        # reported to the caller as the string "Error".
        return "Error"
def retrieve_docs(query: str, k: int = 3) -> str:
    """
    Perform vector similarity search over the FAISS index.

    Args:
        query: the user's query string to embed and search for.
        k: the number of nearest documents to return (default 3). Values
            larger than the index size are clamped to the index size.
    Returns:
        The matching document contents joined with a ``---`` separator, or
        an empty string when ``k`` is not positive or the index is empty.
    """
    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with the id -1, and DOCS[-1] would silently return the *last*
    # document instead of "no result".
    k = min(k, INDEX.ntotal)
    if k <= 0:
        return ""

    q_emb = EMBEDDER.encode([query]).astype("float32")
    _, neighbor_ids = INDEX.search(q_emb, k)

    # neighbor_ids has one row per query; we sent exactly one query.
    # The `>= 0` filter is a second line of defence against -1 padding.
    return "\n\n---\n\n".join(DOCS[i] for i in neighbor_ids[0] if i >= 0)
# Read once at import time; None when the environment variable is not set.
SERPAPI_KEY = os.getenv("SERPAPI_KEY")


def web_search(query: str, num_results: int = 5) -> str:
    """
    Run a Google search via SerpAPI and return the top snippets.

    Args:
        query: the search query.
        num_results: how many results to fetch (default 5).
    Returns:
        A newline-separated list of ``- <snippet>`` lines, at most
        ``num_results`` of them. Results that carry no snippet are skipped,
        so the string is empty when nothing usable came back.
    """
    params = {
        "engine": "google",
        "q": query,
        "num": num_results,
        "api_key": SERPAPI_KEY,
    }
    # `or []` also covers a response where the key is present but null.
    organic = GoogleSearch(params).get_dict().get("organic_results") or []

    # Cap on our side as well, and drop hits without a snippet so the output
    # does not contain empty "- " bullets.
    snippets = (hit.get("snippet") for hit in organic[:num_results])
    return "\n".join(f"- {snippet}" for snippet in snippets if snippet)
def wiki_search(query: str) -> str:
    """
    Look up `query` on Wikipedia and return the text of the pages found.

    Args:
        query: the topic to look up on Wikipedia.
    Returns:
        The `page_content` of up to 2 loaded pages, joined with a
        ``---`` separator line.
    """
    loader = WikipediaLoader(query=query, load_max_docs=2)
    page_texts = [doc.page_content for doc in loader.load()]
    separator = "\n\n---\n\n"
    return separator.join(page_texts)
def arxiv_search(query: str) -> str:
    """
    Search ArXiv for `query` and return a short excerpt of each hit.

    Args:
        query: the search query for ArXiv.
    Returns:
        The first 1000 characters of the `page_content` of up to 3 loaded
        documents, joined with a ``---`` separator line.
    """
    loader = ArxivLoader(query=query, load_max_docs=3)
    excerpts = []
    for paper in loader.load():
        # Truncate each document so the combined result stays short.
        excerpts.append(paper.page_content[:1000])
    separator = "\n\n---\n\n"
    return separator.join(excerpts)
# ────────────────
# 5️⃣ Define your State schema
# ────────────────
| from typing import TypedDict, List | |
| from langchain_core.messages import BaseMessage | |
class AgentState(TypedDict):
    """State dictionary used as the schema of the agent graph."""

    # The running "chat history": every message exchanged so far, in order.
    messages: List[BaseMessage]
# ────────────────
# 6️⃣ Build the StateGraph
# ────────────────
# Read once at import time; None when the environment variable is not set.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")


def build_graph(provider: str = "groq") -> StateGraph:
    """
    Build and compile the question-answering agent graph.

    Flow::

        init -> human -> assistant --(tool calls?)--> tools -> assistant ...
                                   \\--(no tool calls)--> answer -> END

    Invoke the compiled graph with a state such as
    ``{"task_id": "...", "question": "..."}``. The final state contains the
    full ``messages`` history, the original ``task_id`` and the model's last
    reply under ``final_answer``; a submission mapping can be built from it
    as ``{state["task_id"]: state["final_answer"]}``.

    Changes relative to the previous version of this function:
      * ``human_node`` referenced an undefined name (``task_id``) and took
        extra positional parameters; it now reads the task id and question
        from the state.
      * ``task_id`` / ``question`` / ``final_answer`` are declared in the
        state schema used by the graph so nodes can read and write them.
      * The assistant used to have unconditional edges to every tool node
        *and* to the answer node; routing is now conditional on whether the
        model's last message contains tool calls.
      * Tool nodes used to pass the assistant's last message text to the tool
        and append the output as a ``HumanMessage``; tools are now bound to
        the model and their output is returned as ``ToolMessage`` objects
        keyed by tool-call id.
      * Nodes no longer mutate the incoming state in place.

    Args:
        provider: name of the LLM provider. Currently unused: the Groq model
            is always used, exactly as before.
    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
        NOTE(review): the ``-> StateGraph`` annotation is kept for signature
        compatibility, but the returned object is the compiled graph.
    """
    # Function-scope imports: names this file does not import at module level.
    from langchain_core.messages import ToolMessage
    from langgraph.graph import END

    class RunState(AgentState, total=False):
        # Extra per-run fields; optional so existing AgentState users still fit.
        task_id: str       # GAIA task id supplied by the caller
        question: str      # question text supplied by the caller
        final_answer: str  # filled in by answer_node

    # Tools the model may call, looked up by function name at dispatch time.
    tool_fns = {
        fn.__name__: fn
        for fn in (calculator, retrieve_docs, web_search, wiki_search, arxiv_search)
    }
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0).bind_tools(
        list(tool_fns.values())
    )

    # 6.1) Node: init - seed the history with the system prompt
    def init_node(_: RunState) -> dict:
        return {"messages": [SystemMessage(content=SYSTEM_PROMPT)]}

    # 6.2) Node: human - append the caller's question to the history
    def human_node(state: RunState) -> dict:
        question = state.get("question", "")
        return {"messages": state["messages"] + [HumanMessage(content=question)]}

    # 6.3) Node: assistant - call the LLM on the current history
    def assistant_node(state: RunState) -> dict:
        ai_msg = llm.invoke(state["messages"])
        return {"messages": state["messages"] + [ai_msg]}

    # 6.4) Node: tools - run every tool call requested by the last AI message
    def tools_node(state: RunState) -> dict:
        last = state["messages"][-1]
        outputs = []
        for call in getattr(last, "tool_calls", None) or []:
            fn = tool_fns.get(call["name"])
            if fn is None:
                result = f"Error: unknown tool {call['name']!r}"
            else:
                try:
                    result = fn(**call["args"])
                except Exception as exc:
                    # Best effort: report the failure back to the model
                    # instead of aborting the whole run.
                    result = f"Error: {exc}"
            outputs.append(
                ToolMessage(content=str(result), tool_call_id=call["id"])
            )
        return {"messages": state["messages"] + outputs}

    # Router: go to the tools node only when the model asked for a tool.
    def route_after_assistant(state: RunState) -> str:
        last = state["messages"][-1]
        return "tools" if getattr(last, "tool_calls", None) else "answer"

    # 6.5) Node: answer - record the text of the last message
    def answer_node(state: RunState) -> dict:
        # the last message could be a BaseMessage or a raw str
        last = state["messages"][-1]
        text = getattr(last, "content", None) or str(last)
        # Stored under "final_answer" because the node itself is named "answer".
        return {"final_answer": text}

    # 6.6) Build the graph
    g = StateGraph(RunState)

    # Register nodes
    g.add_node("init", init_node)
    g.add_node("human", human_node)
    g.add_node("assistant", assistant_node)
    g.add_node("tools", tools_node)
    g.add_node("answer", answer_node)

    # Wire up edges
    g.set_entry_point("init")
    g.add_edge("init", "human")
    g.add_edge("human", "assistant")
    # assistant -> tools (when tool calls are present) or -> answer
    g.add_conditional_edges(
        "assistant",
        route_after_assistant,
        {"tools": "tools", "answer": "answer"},
    )
    # tool results flow back into the assistant for follow-up
    g.add_edge("tools", "assistant")
    g.add_edge("answer", END)

    return g.compile()