Spaces:
Sleeping
Sleeping
| # agent.py | |
| import os | |
| import pandas as pd | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from serpapi import GoogleSearch | |
| # 1οΈβ£ Switch Graph β StateGraph | |
| from langgraph.graph import StateGraph | |
| from langchain_core.language_models.llms import LLM | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_core.tools import tool | |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader | |
# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
# Module-level singletons built once at import time; retrieve_docs() below
# reads DOCS, EMBEDDER and INDEX.
df = pd.read_csv("documents.csv")  # assumes a "content" column holds the corpus text — TODO confirm
DOCS = df["content"].tolist()
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# FAISS requires float32 vectors; encode() may return a different dtype.
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
# Exact (brute-force) L2 index; dimensionality taken from the embeddings.
INDEX = faiss.IndexFlatL2(EMBS.shape[1])
INDEX.add(EMBS)
# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
# Read once at import; seeds the SystemMessage in the graph's init node.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read().strip()
| # ββββββββββββββββ | |
| # 4οΈβ£ Define your tools (unchanged semantics) | |
| # ββββββββββββββββ | |
def calculator(expr: str) -> str:
    """Safely evaluate a basic arithmetic expression and return the result as a string.

    Supports +, -, *, /, //, %, ** and unary +/- over numeric literals
    (and parentheses, which the parser handles for free).

    Args:
        expr: Arithmetic expression, e.g. "2 + 3 * 4".

    Returns:
        The result rendered with str(), or "Error" for anything that is not
        pure arithmetic (names, calls, attribute access, syntax errors,
        division by zero).

    Note:
        Replaces the previous unrestricted eval(), which executed arbitrary
        Python supplied by the model/user, and the bare ``except:`` that
        swallowed even KeyboardInterrupt.
    """
    import ast
    import operator

    # Whitelisted AST operator nodes -> their pure arithmetic implementations.
    ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def ev(node):
        # Recursively evaluate only the whitelisted node types.
        if isinstance(node, ast.Expression):
            return ev(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in ops:
            return ops[type(node.op)](ev(node.left), ev(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in ops:
            return ops[type(node.op)](ev(node.operand))
        raise ValueError("unsupported expression")

    try:
        return str(ev(ast.parse(expr, mode="eval")))
    except Exception:
        return "Error"
def retrieve_docs(query: str, k: int = 3) -> str:
    """Return the k documents nearest to *query* from the module FAISS index.

    Args:
        query: Free-text search query.
        k: Number of nearest neighbours to fetch.

    Returns:
        The matching documents joined by a "---" separator.
    """
    query_vec = EMBEDDER.encode([query]).astype("float32")
    _, neighbor_ids = INDEX.search(query_vec, k)
    hits = [DOCS[idx] for idx in neighbor_ids[0]]
    return "\n\n---\n\n".join(hits)
SERPAPI_KEY = os.getenv("SERPAPI_KEY")


def web_search(query: str, num_results: int = 5) -> str:
    """Run a Google search through SerpAPI and return the result snippets.

    Args:
        query: Search query string.
        num_results: Maximum number of organic results requested.

    Returns:
        One "- <snippet>" bullet per organic result, newline-joined.
    """
    search = GoogleSearch({
        "engine": "google",
        "q": query,
        "num": num_results,
        "api_key": SERPAPI_KEY,
    })
    organic = search.get_dict().get("organic_results", [])
    bullets = [f"- {hit.get('snippet','')}" for hit in organic]
    return "\n".join(bullets)
def wiki_search(query: str) -> str:
    """Fetch up to two Wikipedia pages for *query* and join their full text.

    Args:
        query: Topic or page title to look up.

    Returns:
        Page contents joined by a "---" separator.
    """
    loader = WikipediaLoader(query=query, load_max_docs=2)
    texts = [page.page_content for page in loader.load()]
    return "\n\n---\n\n".join(texts)
def arxiv_search(query: str) -> str:
    """Fetch up to three arXiv papers for *query*, truncated to 1000 chars each.

    Args:
        query: arXiv search query.

    Returns:
        Truncated paper texts joined by a "---" separator.
    """
    loader = ArxivLoader(query=query, load_max_docs=3)
    excerpts = [paper.page_content[:1000] for paper in loader.load()]
    return "\n\n---\n\n".join(excerpts)
| # ββββββββββββββββ | |
| # 5οΈβ£ Define your State schema | |
| # ββββββββββββββββ | |
| from typing import TypedDict, List | |
| from langchain_core.messages import BaseMessage | |
class AgentState(TypedDict):
    # Running chat history carried between graph nodes: the system prompt
    # first, then human/assistant/tool-output messages appended as nodes run.
    messages: List[BaseMessage]
| # ββββββββββββββββ | |
| # 6οΈβ£ Build the StateGraph | |
| # ββββββββββββββββ | |
def build_graph(provider: str = "huggingface") -> StateGraph:
    """Construct and compile the agent's LangGraph state machine.

    Args:
        provider: Provider name forwarded to the LLM constructor.

    Returns:
        The compiled graph. NOTE(review): the annotation says ``StateGraph``
        but ``g.compile()`` returns the compiled-graph type — confirm and fix
        the annotation.

    Raises:
        ValueError: If the HF_TOKEN environment variable is unset.
    """
    # Instantiate LLM
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN missing in env")
    # NOTE(review): `LLM` from langchain_core is an abstract base class and
    # takes no provider/token/model kwargs — this line will fail at runtime.
    # Presumably a concrete HuggingFace wrapper class was intended; confirm.
    llm = LLM(provider=provider, token=hf_token, model="meta-llama/Llama-2-7b-chat-hf")

    # 6.1) Node: init — seed the system prompt (returns a fresh history).
    def init_node(_: AgentState) -> AgentState:
        return {
            "messages": [
                SystemMessage(content=SYSTEM_PROMPT)
            ]
        }

    # 6.2) Node: human — append the user question.
    # NOTE(review): LangGraph invokes a node with the state only; the extra
    # positional `question` parameter will raise a TypeError when the graph
    # runs. The question likely belongs in the initial state — confirm.
    def human_node(state: AgentState, question: str) -> AgentState:
        state["messages"].append(HumanMessage(content=question))
        return state

    # 6.3) Node: assistant — call the LLM on the current message history.
    def assistant_node(state: AgentState) -> dict:
        ai_msg = llm.invoke(state["messages"])
        return {"messages": state["messages"] + [ai_msg]}

    # 6.4) Tool nodes: each wraps one tool function.
    # NOTE(review): by the time a tool node runs (human → assistant → tool),
    # messages[-1] is the assistant's reply, not the human query as the
    # comment below claims — verify which text each tool should receive.
    def make_tool_node(fn):
        def tool_node(state: AgentState) -> dict:
            # fetch the latest human query
            last_query = state["messages"][-1].content
            result = fn(last_query)
            # append the tool's output as if from system/Human
            state["messages"].append(HumanMessage(content=result))
            return {"messages": state["messages"]}
        return tool_node

    # Instantiate one node per tool.
    calc_node = make_tool_node(calculator)
    retrieve_node = make_tool_node(retrieve_docs)
    web_node = make_tool_node(web_search)
    wiki_node = make_tool_node(wiki_search)
    arxiv_node = make_tool_node(arxiv_search)

    # 6.5) Build the graph
    g = StateGraph(AgentState)
    # Register nodes
    g.add_node("init", init_node)
    g.add_node("human", human_node)
    g.add_node("assistant", assistant_node)
    g.add_node("calc", calc_node)
    g.add_node("retrieve", retrieve_node)
    g.add_node("web", web_node)
    g.add_node("wiki", wiki_node)
    g.add_node("arxiv", arxiv_node)
    # Wire up edges
    from langgraph.graph import END
    g.set_entry_point("init")
    # init → human (placeholder: the actual question is injected at runtime)
    g.add_edge("init", "human")
    # human → assistant
    g.add_edge("human", "assistant")
    # assistant → tool nodes
    # NOTE(review): add_edge is UNconditional — despite the original
    # "conditional on tool calls" comment, the assistant fans out to every
    # tool node (and END) on each turn, executing them in parallel.
    # Conditional routing needs add_conditional_edges with a router — confirm.
    g.add_edge("assistant", "calc")
    g.add_edge("assistant", "retrieve")
    g.add_edge("assistant", "web")
    g.add_edge("assistant", "wiki")
    g.add_edge("assistant", "arxiv")
    # each tool feeds back into assistant for a follow-up turn
    g.add_edge("calc", "assistant")
    g.add_edge("retrieve", "assistant")
    g.add_edge("web", "assistant")
    g.add_edge("wiki", "assistant")
    g.add_edge("arxiv", "assistant")
    # and finally assistant → END when done
    g.add_edge("assistant", END)
    return g.compile()