Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| from langgraph.graph import START, StateGraph, MessagesState | |
| from langgraph.prebuilt import tools_condition | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_community.tools.tavily_search import TavilySearchResults # 已经导入了 | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langchain_community.document_loaders import ArxivLoader | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_core.tools import tool | |
| # from langchain_openai import ChatOpenAI | |
| from langchain_deepseek import ChatDeepSeek | |
# NOTE: load_dotenv() is deliberately not called here; it is assumed that
# app.py (or another entry point) has already loaded the .env file.

# Read both API keys from the environment up front.
# TAVILY_API_KEY has to be added to the Space Secrets.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Fail fast at import time when a key is missing or empty. Tavily is
# critical for most questions, so its absence is an error as well.
for _key_value, _missing_message in (
    (DEEPSEEK_API_KEY, "DEEPSEEK_API_KEY not found in environment variables."),
    (
        TAVILY_API_KEY,
        "TAVILY_API_KEY not found in environment variables. Please add it to your Space Secrets.",
    ),
):
    if not _key_value:
        raise ValueError(_missing_message)
del _key_value, _missing_message
| # Keep Wikipedia and Arxiv, but the general search will be more used | |
def wiki_search(query: str) -> str:
    """Using Wikipedia, search for a query and return up to 2 relevant results."""
    # Failures are reported back as text instead of being raised, so the
    # caller always receives a string.
    try:
        # At most 2 pages, each truncated to 2000 characters of content.
        loader = WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000)
        documents = loader.load()
        if not documents:
            return "Wikipedia search found no relevant pages."

        rendered = []
        for document in documents:
            meta = document.metadata
            rendered.append(
                f'<Document source="Wikipedia - {meta.get("source", "")}" page="{meta.get("page", "")}">\n{document.page_content}\n</Document>'
            )
        # One <Document> block per page, separated by a horizontal rule.
        return "\n\n---\n\n".join(rendered)
    except Exception as e:
        return f"An error occurred during Wikipedia search: {e}"
| # *** ADD TAVILY WEB SEARCH TOOL *** | |
def web_search(query: str) -> str:
    """Search the web for a query using Tavily and return relevant snippets.

    Args:
        query: Free-text search query.

    Returns:
        Up to five results rendered as ``<SearchResult>`` blocks separated by
        ``---`` lines, or a plain-text message when nothing was found or the
        search failed. Errors are never raised to the caller.
    """
    try:
        tavily = TavilySearchResults(max_results=5)  # Get up to 5 results
        results = tavily.invoke(query)
        if not results:
            return "Web search found no relevant results."
        # NOTE(review): TavilySearchResults appears to report its own failures
        # as a string instead of a list of dicts -- confirm against the
        # installed langchain_community version.
        if isinstance(results, str):
            return f"An error occurred during web search: {results}"

        # Tavily result dicts carry the page address under "url"; the former
        # "source" lookup raised KeyError on every hit. "title" is not present
        # in every library version, so all fields are read defensively.
        return "\n\n---\n\n".join(
            f'<SearchResult source="{r.get("url") or r.get("source", "")}">\n'
            f'Title: {r.get("title", "")}\nContent: {r.get("content", "")}\n</SearchResult>'
            for r in results
        )
    except Exception as e:
        return f"An error occurred during web search: {e}"
def duckduckgo_search(query: str) -> str:
    """Search the web for a query using DuckDuckGo and return relevant snippets.

    Args:
        query: Free-text search query.

    Returns:
        The search output wrapped in a single ``<SearchResult>`` element, or a
        plain-text message when nothing was found or the search failed.
        Errors are never raised to the caller.
    """
    try:
        # DuckDuckGoSearchRun was never imported at module level, so every
        # call used to fail with NameError. Importing inside the try keeps a
        # missing optional dependency a reported error instead of a crash.
        from langchain_community.tools import DuckDuckGoSearchRun

        search_tool = DuckDuckGoSearchRun()
        results = search_tool.invoke(query)
        if not results or not results.strip():
            return "DuckDuckGo search found no relevant results."
        return f"<SearchResult source=\"DuckDuckGo\">{results}</SearchResult>"
    except Exception as e:
        return f"An error occurred during DuckDuckGo search: {e}"
def arithmetic(expression: str) -> str:
    """Evaluate a basic arithmetic expression and return the result as text.

    Supports int/float literals, parentheses, unary ``+``/``-`` and the binary
    operators ``+``, ``-``, ``*``, ``/``, ``//``, ``%`` and ``**``.

    Args:
        expression: The arithmetic expression to evaluate, e.g. ``"2 * (3 + 4)"``.

    Returns:
        ``str()`` of the computed value, or an error message when the
        expression is invalid, uses an unsupported construct, or fails to
        evaluate (for example division by zero). Errors are never raised.
    """
    import ast
    import operator

    # SECURITY: this used to call eval() with empty __builtins__. That is not
    # a sandbox (attribute access such as ().__class__ still escapes it), and
    # the expression ultimately comes from model output. The expression is
    # therefore parsed and only a whitelist of arithmetic nodes is evaluated.
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Upper bound on the size of an integer power result, so that inputs like
    # 9 ** 9 ** 9 ** 9 are rejected instead of exhausting CPU and memory.
    max_result_bits = 100_000

    def evaluate(node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; reject it along with str, None, ...
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError(f"unsupported literal: {value!r}")
            return value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and left.bit_length() * right > max_result_bits
            ):
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError(f"unsupported expression element: {type(node).__name__}")

    try:
        # compile-style parsing rejects leading whitespace, which eval()
        # tolerated, so strip it first.
        tree = ast.parse(expression.strip(), mode="eval")
        return str(evaluate(tree.body))
    except Exception as e:
        return f"计算表达式时出错: {e}"
# Read the system prompt text from system_prompt.txt (resolved against the
# current working directory). The file must exist and hold the prompt content
# from Step 2.
with open("system_prompt.txt", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()

# Ready-made system message built from the prompt text.
sys_msg = SystemMessage(content=system_prompt)

# Every callable handed to the agent as a tool.
tools = [wiki_search, duckduckgo_search, web_search, arithmetic]
def build_graph():
    """Assemble and compile the tool-calling agent graph.

    The graph has two nodes: "assistant", which invokes the DeepSeek chat
    model with the module-level ``tools`` bound, and "tools", which executes
    the requested tool calls. Execution starts at "assistant" and loops
    assistant -> tools -> assistant for as long as ``tools_condition`` routes
    to "tools".

    Returns:
        The compiled graph.
    """
    chat_model = ChatDeepSeek(
        model="deepseek-chat",
        temperature=0,  # Keep low for factual answers
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=DEEPSEEK_API_KEY,
        base_url="https://api.deepseek.com",
    )
    tool_enabled_model = chat_model.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: invoke LLM with tools."""
        print("---Calling Assistant---")  # Debug trace
        history = state["messages"]
        # Make sure a system message leads the conversation: prepend one
        # only when the history does not already contain any.
        has_system_message = False
        for message in history:
            if isinstance(message, SystemMessage):
                has_system_message = True
                break
        if not has_system_message:
            history = [SystemMessage(content=system_prompt)] + history
        reply = tool_enabled_model.invoke(history)
        return {"messages": [reply]}

    builder = StateGraph(MessagesState)

    # Nodes
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Entry point
    builder.add_edge(START, "assistant")

    # tools_condition inspects the assistant's latest message: when it
    # contains tool calls the run continues at "tools"; otherwise the graph
    # ends. That is how the agent stops.
    builder.add_conditional_edges("assistant", tools_condition)

    # After the tools ran, their results are part of the state and control
    # returns to the assistant, which decides the next step.
    builder.add_edge("tools", "assistant")

    # NOTE: if the agent runs into the recursion limit, prefer fixing the
    # LLM's logic through the prompt first. The limit itself is a run-time
    # option (e.g. graph.invoke(..., config={"recursion_limit": 50})).
    return builder.compile()
def process_answer(answer):
    """Reduce a model reply to the bare answer text.

    Args:
        answer: The text content of the final message.

    Returns:
        * the text following the first "FINAL ANSWER:" marker
          (case-insensitive, up to the end of that line), if present;
        * otherwise, for replies longer than 15 words that contain a period,
          the last sentence without its trailing period;
        * otherwise the reply with surrounding whitespace removed.
    """
    import re

    match = re.search(r'(?i)FINAL ANSWER:\s*(.*)', answer)
    if match:
        return match.group(1).strip()

    if len(answer.split()) > 15 and "." in answer:
        # Split only on periods that end a sentence (followed by whitespace
        # or the end of the text). Splitting on every "." used to cut
        # decimals apart, e.g. "... is 3.14" was reduced to "14".
        sentences = [s.strip() for s in re.split(r'\.(?:\s+|$)', answer) if s.strip()]
        if sentences:
            return sentences[-1]

    return answer.strip()


if __name__ == "__main__":
    # Example usage (for local testing).
    # To run this part, make sure DEEPSEEK_API_KEY and TAVILY_API_KEY are set
    # in your environment or in a .env file loaded beforehand (when running
    # locally you would typically call `load_dotenv()` here or in app.py).

    # Test questions covering different tool needs
    questions_for_testing = [
        "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?",  # Web Search
        "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species seen?",  # Requires video analysis (will likely fail with current tools)
        ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI",  # Text manipulation (no tool needed)
        "What is 12345 * 6789?",  # Calculator
        "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2023?",  # Web Search/Wikipedia
        "What country had the least number of athletes at the 1928 Summer Olympics?",  # Web Search
        "Review the chess position provided in the image. It is black's turn. Provide the correct next move from this position: [Describe the position or mention image input which is not supported]",  # Requires image analysis (will likely fail)
        # Add more questions from your evaluation set to test
    ]

    graph = build_graph()

    # Optional: Draw graph
    # try:
    #     png_data = graph.get_graph().draw_mermaid_png()
    #     with open("graph.png", "wb") as f:
    #         f.write(png_data)
    #     print("Graph visualization saved to graph.png")
    # except Exception as e:
    #     print(f"Could not draw graph: {e}")

    print("\n--- Running single question tests ---")
    for i, question in enumerate(questions_for_testing):
        print(f"\n--- Testing Question {i+1}: {question}")
        try:
            # LangGraph returns the final state after execution completes or
            # hits the recursion limit.
            final_state = graph.invoke(
                {"messages": [SystemMessage(content=system_prompt), HumanMessage(content=question)]}
            )

            # Post-process the last message before reporting the answer.
            final_answer = final_state["messages"][-1].content
            processed_answer = process_answer(final_answer)
            print(f"\n--- Processed Answer: {processed_answer}")

            print("\n--- Final State Messages ---")
            for m in final_state["messages"]:
                m.pretty_print()
            print("-" * 30)
        except Exception as e:
            # Keep going with the remaining questions when one of them fails.
            print(f"--- Error running graph for this question: {e}")
            print("-" * 30)