File size: 9,992 Bytes
979ff4b
 
 
c4ed60f
 
 
979ff4b
c4ed60f
979ff4b
 
c4ed60f
979ff4b
 
909afab
c4ed60f
 
3c41942
c4ed60f
979ff4b
3c41942
 
 
c4ed60f
 
979ff4b
c4ed60f
 
 
979ff4b
3c41942
c4ed60f
3c41942
c4ed60f
3c41942
c4ed60f
3c41942
 
 
 
 
c4ed60f
3c41942
 
 
c4ed60f
 
 
3c41942
 
c4ed60f
3c41942
c4ed60f
3c41942
 
c4ed60f
 
3c41942
 
 
979ff4b
c4ed60f
3c41942
c4ed60f
 
d348bcf
 
c4ed60f
d348bcf
c4ed60f
 
 
909afab
c4ed60f
d348bcf
 
c4ed60f
979ff4b
e3e1920
 
 
 
 
 
 
 
 
 
 
 
 
c4ed60f
 
 
 
 
979ff4b
 
 
909afab
e3e1920
 
979ff4b
3c41942
c4ed60f
979ff4b
 
60ea59a
c4ed60f
3c41942
 
 
 
 
 
979ff4b
3c41942
c4ed60f
979ff4b
dc6c730
 
 
 
 
 
 
 
 
0569026
dc6c730
3c41942
979ff4b
 
3c41942
 
979ff4b
c4ed60f
 
 
 
 
 
 
 
 
 
 
 
 
 
979ff4b
3c41942
c4ed60f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3629510
0569026
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4ed60f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import ast
import operator
import os
import re

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
# from langchain_openai import ChatOpenAI
from langchain_deepseek import ChatDeepSeek


# load_dotenv() is expected to have run already (e.g. in app.py), so the
# credentials are read straight from the process environment here.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")  # add TAVILY_API_KEY to your Space Secrets

# Fail fast at import time when a required credential is missing or empty.
if not DEEPSEEK_API_KEY:
    raise ValueError("DEEPSEEK_API_KEY not found in environment variables.")
if not TAVILY_API_KEY:
    # Tavily powers the main web-search tool, so it is mandatory as well.
    raise ValueError("TAVILY_API_KEY not found in environment variables. Please add it to your Space Secrets.")



# Keep Wikipedia and Arxiv, but the general search will be more used
@tool
def wiki_search(query: str) -> str:
    "Using Wikipedia, search for a query and return up to 2 relevant results."
    try:
        # Cap each page at 2000 characters to keep the tool output compact.
        docs = WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load()
        if not docs:
            return "Wikipedia search found no relevant pages."
        # Wrap each page in a <Document> tag so the LLM can tell results apart.
        rendered = []
        for doc in docs:
            meta = doc.metadata
            rendered.append(
                f'<Document source="Wikipedia - {meta.get("source", "")}" page="{meta.get("page", "")}">\n{doc.page_content}\n</Document>'
            )
        return "\n\n---\n\n".join(rendered)
    except Exception as e:
        return f"An error occurred during Wikipedia search: {e}"



# *** ADD TAVILY WEB SEARCH TOOL ***
@tool
def web_search(query: str) -> str:
    """Search the web for a query using Tavily and return relevant snippets."""
    try:
        tavily = TavilySearchResults(max_results=5)  # get up to 5 results
        results = tavily.invoke(query)
        if not results:
            return "Web search found no relevant results."
        # Each Tavily hit is a dict with "url", "title" and "content" keys;
        # there is no "source" key, so the previous r["source"] lookup raised
        # KeyError and every successful search returned the error string.
        formatted_results = "\n\n---\n\n".join([
            f'<SearchResult source="{r.get("url", "")}">\nTitle: {r.get("title", "")}\nContent: {r.get("content", "")}\n</SearchResult>'
            for r in results
        ])
        return formatted_results
    except Exception as e:
        return f"An error occurred during web search: {e}"
        
@tool
def duckduckgo_search(query: str) -> str:
    """Search the web for a query using DuckDuckGo and return relevant snippets."""
    try:
        # DuckDuckGoSearchRun was previously never imported, so every call
        # failed with a NameError; the import now lives at the top of the file.
        search_tool = DuckDuckGoSearchRun()
        results = search_tool.invoke(query)
        if not results or not results.strip():
            return "DuckDuckGo search found no relevant results."
        return f'<SearchResult source="DuckDuckGo">{results}</SearchResult>'
    except Exception as e:
        return f"An error occurred during DuckDuckGo search: {e}"
        

# Operator whitelist for the arithmetic tool: basic arithmetic only.
_ARITH_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_ARITH_UNARYOPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _eval_arith_node(node):
    """Recursively evaluate a whitelisted arithmetic AST node.

    Raises ValueError for anything other than numeric literals and the
    binary/unary operators listed above.
    """
    if isinstance(node, ast.Expression):
        return _eval_arith_node(node.body)
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _ARITH_BINOPS:
        return _ARITH_BINOPS[type(node.op)](_eval_arith_node(node.left), _eval_arith_node(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in _ARITH_UNARYOPS:
        return _ARITH_UNARYOPS[type(node.op)](_eval_arith_node(node.operand))
    raise ValueError(f"unsupported expression element: {ast.dump(node)}")


@tool
def arithmetic(expression: str) -> str:
    """Evaluate a mathematical expression and return the result. Supports basic arithmetic such as +, -, *, /, //, % and ** (power)."""
    try:
        # eval() -- even with empty __builtins__ -- is not safe on
        # LLM-supplied strings (attribute chains can reach dangerous objects).
        # Parse the expression and walk the AST, allowing only numeric
        # literals and basic arithmetic operators.
        tree = ast.parse(expression, mode="eval")
        result = _eval_arith_node(tree)
        return str(result)
    except Exception as e:
        return f"计算表达式时出错: {e}"
        
# Load the agent's system prompt from disk at import time.
# "system_prompt.txt" must sit next to this file; a missing file raises
# FileNotFoundError immediately, which is intentional (fail fast).
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
# NOTE(review): sys_msg is built here, but build_graph() constructs its own
# SystemMessage from system_prompt -- confirm whether sys_msg is still used.
sys_msg = SystemMessage(content=system_prompt)

# Tools the agent may call; this list is bound to the LLM via bind_tools()
# and wrapped in the ToolNode inside build_graph().
tools = [
    wiki_search,
    duckduckgo_search,
    web_search,
    arithmetic,
]


def build_graph():
    """Build and compile the LangGraph agent: an assistant node plus a tool node."""
    # DeepSeek chat model with deterministic settings for factual answers.
    llm = ChatDeepSeek(
        model="deepseek-chat",
        temperature=0,  # keep low for factual answers
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=DEEPSEEK_API_KEY,
        base_url="https://api.deepseek.com",
    )
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: invoke LLM with tools."""
        print("---Calling Assistant---")
        # Prepend the system prompt unless one is already in the history.
        history = state["messages"]
        has_system = any(isinstance(msg, SystemMessage) for msg in history)
        if not has_system:
            history = [SystemMessage(content=system_prompt)] + history
        reply = llm_with_tools.invoke(history)
        return {"messages": [reply]}

    workflow = StateGraph(MessagesState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(tools))
    workflow.add_edge(START, "assistant")

    # tools_condition routes to "tools" when the assistant's last message
    # contains tool calls; otherwise the path implicitly ends, which is how
    # the agent stops.
    workflow.add_conditional_edges("assistant", tools_condition)

    # Tool output is appended to the state and control returns to the
    # assistant, which decides the next step.
    workflow.add_edge("tools", "assistant")

    # The recursion limit can be raised at compile time if needed, but fixing
    # the LLM's looping behaviour via the prompt is preferable.
    return workflow.compile()


def process_answer(answer: str) -> str:
    """Strip explanatory text from a final answer.

    Extracts the text after "FINAL ANSWER:" when present (case-insensitive);
    otherwise, for long multi-sentence answers, falls back to the last
    sentence; otherwise returns the stripped answer unchanged.
    """
    # If the answer contains "FINAL ANSWER:", extract the part after it.
    if "FINAL ANSWER:" in answer.upper():
        match = re.search(r'(?i)FINAL ANSWER:\s*(.*)', answer)
        if match:
            return match.group(1).strip()
    # For long answers with several sentences, use the last sentence.
    if len(answer.split()) > 15 and "." in answer:
        sentences = [s.strip() for s in answer.split(".") if s.strip()]
        if sentences:
            return sentences[-1].strip()
    return answer.strip()


if __name__ == "__main__":
    # Example usage for local testing.  Requires DEEPSEEK_API_KEY and
    # TAVILY_API_KEY in the environment (e.g. via load_dotenv() in app.py).

    # Test questions covering different tool needs.
    questions_for_testing = [
        "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?",  # web search
        "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species seen?",  # needs video analysis (will likely fail)
        ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI",  # text manipulation (no tool needed)
        "What is 12345 * 6789?",  # calculator
        "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2023?",  # web search / Wikipedia
        "What country had the least number of athletes at the 1928 Summer Olympics?",  # web search
        "Review the chess position provided in the image. It is black's turn. Provide the correct next move from this position: [Describe the position or mention image input which is not supported]",  # needs image analysis (will likely fail)
        # Add more questions from your evaluation set to test.
    ]

    graph = build_graph()

    print("\n--- Running single question tests ---")
    for i, question in enumerate(questions_for_testing):
        print(f"\n--- Testing Question {i+1}: {question}")
        try:
            # graph.invoke returns the final state once execution completes
            # (or the recursion limit is hit).
            final_state = graph.invoke(
                {"messages": [SystemMessage(content=system_prompt), HumanMessage(content=question)]}
            )

            # process_answer is now defined once at module level instead of
            # being re-created on every loop iteration.
            final_answer = final_state["messages"][-1].content
            processed_answer = process_answer(final_answer)
            print(f"\n--- Processed Answer: {processed_answer}")

            print("\n--- Final State Messages ---")
            for m in final_state["messages"]:
                m.pretty_print()
            print("-" * 30)
        except Exception as e:
            print(f"--- Error running graph for this question: {e}")
            print("-" * 30)