RichardHu committed on
Commit
0295215
·
verified ·
1 Parent(s): 525b8d4

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +214 -214
tools.py CHANGED
@@ -1,257 +1,257 @@
1
- # from smolagents import DuckDuckGoSearchTool
2
- # from smolagents import Tool
3
- # import random
4
- # from huggingface_hub import list_models
5
 
6
 
7
- # # Initialize the DuckDuckGo search tool
8
- # #search_tool = DuckDuckGoSearchTool()
9
 
10
 
11
- # class WeatherInfoTool(Tool):
12
- # name = "weather_info"
13
- # description = "Fetches dummy weather information for a given location."
14
- # inputs = {
15
- # "location": {
16
- # "type": "string",
17
- # "description": "The location to get weather information for."
18
- # }
19
- # }
20
- # output_type = "string"
21
 
22
- # def forward(self, location: str):
23
- # # Dummy weather data
24
- # weather_conditions = [
25
- # {"condition": "Rainy", "temp_c": 15},
26
- # {"condition": "Clear", "temp_c": 25},
27
- # {"condition": "Windy", "temp_c": 20}
28
- # ]
29
- # # Randomly select a weather condition
30
- # data = random.choice(weather_conditions)
31
- # return f"Weather in {location}: {data['condition']}, {data['temp_c']}°C"
32
 
33
- # class HubStatsTool(Tool):
34
- # name = "hub_stats"
35
- # description = "Fetches the most downloaded model from a specific author on the Hugging Face Hub."
36
- # inputs = {
37
- # "author": {
38
- # "type": "string",
39
- # "description": "The username of the model author/organization to find models from."
40
- # }
41
- # }
42
- # output_type = "string"
43
 
44
- # def forward(self, author: str):
45
- # try:
46
- # # List models from the specified author, sorted by downloads
47
- # models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
48
 
49
- # if models:
50
- # model = models[0]
51
- # return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
52
- # else:
53
- # return f"No models found for author {author}."
54
- # except Exception as e:
55
- # return f"Error fetching models for {author}: {str(e)}"
56
 
57
 
58
- from typing import TypedDict, List, Optional, Annotated
59
- from langchain_core.messages import BaseMessage
60
- from langgraph.graph import StateGraph, END
61
- from langchain_core.prompts import ChatPromptTemplate
62
- from langchain_openai import ChatOpenAI
63
- from retriever import get_retriever
64
- import json
65
 
66
- # 定义状态对象
67
- class GraphState(TypedDict):
68
- question: str
69
- documents: List[str]
70
- answer: str
71
- verification: Annotated[Optional[dict], "验证结果"]
72
- retries: Annotated[int, "剩余重试次数"]
73
- feedback: Annotated[Optional[str], "前次验证的反馈"]
74
- history: Annotated[List[dict], "执行历史记录"]
75
 
76
- # 初始化检索器和模型
77
- retriever = get_retriever()
78
- llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
79
 
80
- def retrieve(state: GraphState):
81
- """检索文档节点"""
82
- history = state["history"]
83
- history.append({"step": "检索", "status": "开始"})
84
 
85
- question = state["question"]
86
- documents = retriever.get_relevant_documents(question)
87
- doc_contents = [doc.page_content for doc in documents]
88
 
89
- history.append({
90
- "step": "检索",
91
- "status": "完成",
92
- "documents": doc_contents
93
- })
94
 
95
- return {"documents": doc_contents, "history": history}
96
 
97
- def generate(state: GraphState):
98
- """生成答案节点"""
99
- history = state["history"]
100
- history.append({"step": "生成", "status": "开始"})
101
 
102
- question = state["question"]
103
- documents = state["documents"]
104
- feedback = state.get("feedback", "")
105
 
106
- # 构建提示词
107
- prompt = ChatPromptTemplate.from_messages([
108
- ("system", "你是一个专业助手,基于以下上下文回答问题。如果上下文不足,请说明。{feedback}"),
109
- ("human", "问题:{question}\n上下文:{context}")
110
- ])
111
 
112
- chain = prompt | llm
113
- context = "\n\n".join(documents)
114
- response = chain.invoke({
115
- "question": question,
116
- "context": context,
117
- "feedback": feedback
118
- })
119
 
120
- history.append({
121
- "step": "生成",
122
- "status": "完成",
123
- "answer": response.content
124
- })
125
 
126
- return {"answer": response.content, "history": history}
127
 
128
- def verify(state: GraphState):
129
- """验证答案节点"""
130
- history = state["history"]
131
- history.append({"step": "验证", "status": "开始"})
132
 
133
- question = state["question"]
134
- answer = state["answer"]
135
- documents = state["documents"]
136
 
137
- # 验证提示词
138
- prompt = ChatPromptTemplate.from_messages([
139
- ("system", "评估答案是否符合以下标准:\n"
140
- "1. 是否基于提供的上下文\n"
141
- "2. 是否完整回答问题\n"
142
- "3. 是否包含幻觉信息\n\n"
143
- "返回JSON格式:{\"valid\": boolean, \"feedback\": string}"),
144
- ("human", "问题:{question}\n答案:{answer}\n上下文:{context}")
145
- ])
146
 
147
- chain = prompt | llm
148
- context = "\n\n".join(documents)
149
- result = chain.invoke({
150
- "question": question,
151
- "answer": answer,
152
- "context": context
153
- })
154
 
155
- try:
156
- # 尝试解析JSON输出
157
- verification = json.loads(result.content)
158
- except:
159
- # 如果解析失败,使用默认值
160
- verification = {"valid": False, "feedback": "验证失败: 无法解析验证结果"}
161
 
162
- history.append({
163
- "step": "验证",
164
- "status": "完成",
165
- "verification": verification
166
- })
167
 
168
- return {"verification": verification, "history": history}
169
 
170
- def should_retry(state: GraphState):
171
- """决定是否重试的条件函数"""
172
- history = state["history"]
173
 
174
- if state["verification"].get("valid", False):
175
- history.append({"step": "决策", "action": "验证通过,结束流程"})
176
- return "end"
177
- elif state["retries"] > 0:
178
- history.append({
179
- "step": "决策",
180
- "action": f"验证失败,剩余重试次数:{state['retries']},将重试"
181
- })
182
- return "retry"
183
- else:
184
- history.append({"step": "决策", "action": "重试次数用尽,结束流程"})
185
- return "end"
186
 
187
- def prepare_retry(state: GraphState):
188
- """准备重试节点"""
189
- history = state["history"]
190
- history.append({"step": "准备重试", "status": "开始"})
191
 
192
- feedback = state["verification"].get("feedback", "需要改进答案")
193
 
194
- history.append({
195
- "step": "准备重试",
196
- "status": "完成",
197
- "feedback": feedback
198
- })
199
 
200
- return {
201
- "feedback": feedback,
202
- "retries": state["retries"] - 1,
203
- "history": history
204
- }
205
 
206
- # 构建工作流
207
- workflow = StateGraph(GraphState)
208
 
209
- # 添加节点
210
- workflow.add_node("retrieve", retrieve)
211
- workflow.add_node("generate", generate)
212
- workflow.add_node("verify", verify)
213
- workflow.add_node("prepare_retry", prepare_retry)
214
 
215
- # 设置入口点
216
- workflow.set_entry_point("retrieve")
217
 
218
- # 添加边
219
- workflow.add_edge("retrieve", "generate")
220
- workflow.add_edge("generate", "verify")
221
- workflow.add_conditional_edges(
222
- "verify",
223
- should_retry,
224
- {
225
- "end": END,
226
- "retry": "prepare_retry"
227
- }
228
- )
229
- workflow.add_edge("prepare_retry", "retrieve")
230
 
231
- # 编译工作流
232
- app = workflow.compile()
233
 
234
- def run_agentic_rag(question: str, max_retries: int = 3):
235
- """运行Agentic RAG工作流"""
236
- initial_state = {
237
- "question": question,
238
- "documents": [],
239
- "answer": "",
240
- "verification": None,
241
- "retries": max_retries,
242
- "feedback": "",
243
- "history": [{"step": "初始化", "status": f"开始处理问题: {question}"}]
244
- }
245
 
246
- # 执行工作流
247
- final_state = None
248
- for step in app.stream(initial_state):
249
- node, state = next(iter(step.items()))
250
- final_state = state
251
 
252
- return {
253
- "answer": final_state["answer"],
254
- "documents": final_state["documents"],
255
- "history": final_state["history"],
256
- "retries_used": max_retries - final_state["retries"]
257
- }
 
1
+ from smolagents import DuckDuckGoSearchTool
2
+ from smolagents import Tool
3
+ import random
4
+ from huggingface_hub import list_models
5
 
6
 
7
+ # Initialize the DuckDuckGo search tool
8
+ #search_tool = DuckDuckGoSearchTool()
9
 
10
 
11
class WeatherInfoTool(Tool):
    """Tool that reports placeholder weather for a location.

    No weather service is contacted: every call picks one of three
    hard-coded conditions at random and formats it into a sentence.
    """

    name = "weather_info"
    description = "Fetches dummy weather information for a given location."
    inputs = {
        "location": {
            "type": "string",
            "description": "The location to get weather information for."
        }
    }
    output_type = "string"

    def forward(self, location: str):
        """Return a one-line dummy weather summary for ``location``.

        Args:
            location: Name of the place; it is only echoed back in the
                result and does not influence which sample is chosen.

        Returns:
            A string of the form
            ``"Weather in <location>: <condition>, <temp>°C"``.
        """
        # Canned samples: condition label plus temperature in Celsius.
        samples = [
            {"condition": "Rainy", "temp_c": 15},
            {"condition": "Clear", "temp_c": 25},
            {"condition": "Windy", "temp_c": 20}
        ]
        picked = random.choice(samples)
        condition = picked["condition"]
        temp_c = picked["temp_c"]
        return f"Weather in {location}: {condition}, {temp_c}°C"
32
 
33
class HubStatsTool(Tool):
    """Tool that names an author's most downloaded model on the Hugging Face Hub.

    The lookup is delegated to ``huggingface_hub.list_models``; every
    outcome, including failures, is reported as a plain string so the
    calling agent always receives text rather than an exception.
    """

    name = "hub_stats"
    description = "Fetches the most downloaded model from a specific author on the Hugging Face Hub."
    inputs = {
        "author": {
            "type": "string",
            "description": "The username of the model author/organization to find models from."
        }
    }
    output_type = "string"

    def forward(self, author: str):
        """Describe the top model for ``author`` when sorted by downloads.

        Args:
            author: Username or organization whose models are listed.

        Returns:
            A sentence naming the model (and its download count when the
            Hub supplies one), a "no models found" sentence when the
            listing is empty, or an "Error fetching models" sentence when
            anything in the lookup raises.
        """
        try:
            # Ask for a single result, sorted by downloads (direction=-1).
            models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))

            if not models:
                return f"No models found for author {author}."

            top_model = models[0]
            downloads = top_model.downloads
            if downloads is None:
                # NOTE(review): huggingface_hub types this field as optional.
                # Formatting None with ":," raises TypeError, which the broad
                # handler below would misreport as a fetch error even though
                # the listing itself succeeded.
                return (
                    f"The most downloaded model by {author} is {top_model.id} "
                    "(download count unavailable)."
                )
            return f"The most downloaded model by {author} is {top_model.id} with {downloads:,} downloads."
        except Exception as e:
            # Deliberate best-effort: surface the failure as text to the agent.
            return f"Error fetching models for {author}: {e}"
56
 
57
 
58
+ # from typing import TypedDict, List, Optional, Annotated
59
+ # from langchain_core.messages import BaseMessage
60
+ # from langgraph.graph import StateGraph, END
61
+ # from langchain_core.prompts import ChatPromptTemplate
62
+ # from langchain_openai import ChatOpenAI
63
+ # from retriever import get_retriever
64
+ # import json
65
 
66
+ # # 定义状态对象
67
+ # class GraphState(TypedDict):
68
+ # question: str
69
+ # documents: List[str]
70
+ # answer: str
71
+ # verification: Annotated[Optional[dict], "验证结果"]
72
+ # retries: Annotated[int, "剩余重试次数"]
73
+ # feedback: Annotated[Optional[str], "前次验证的反馈"]
74
+ # history: Annotated[List[dict], "执行历史记录"]
75
 
76
+ # # 初始化检索器和模型
77
+ # retriever = get_retriever()
78
+ # llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
79
 
80
+ # def retrieve(state: GraphState):
81
+ # """检索文档节点"""
82
+ # history = state["history"]
83
+ # history.append({"step": "检索", "status": "开始"})
84
 
85
+ # question = state["question"]
86
+ # documents = retriever.get_relevant_documents(question)
87
+ # doc_contents = [doc.page_content for doc in documents]
88
 
89
+ # history.append({
90
+ # "step": "检索",
91
+ # "status": "完成",
92
+ # "documents": doc_contents
93
+ # })
94
 
95
+ # return {"documents": doc_contents, "history": history}
96
 
97
+ # def generate(state: GraphState):
98
+ # """生成答案节点"""
99
+ # history = state["history"]
100
+ # history.append({"step": "生成", "status": "开始"})
101
 
102
+ # question = state["question"]
103
+ # documents = state["documents"]
104
+ # feedback = state.get("feedback", "")
105
 
106
+ # # 构建提示词
107
+ # prompt = ChatPromptTemplate.from_messages([
108
+ # ("system", "你是一个专业助手,基于以下上下文回答问题。如果上下文不足,请说明。{feedback}"),
109
+ # ("human", "问题:{question}\n上下文:{context}")
110
+ # ])
111
 
112
+ # chain = prompt | llm
113
+ # context = "\n\n".join(documents)
114
+ # response = chain.invoke({
115
+ # "question": question,
116
+ # "context": context,
117
+ # "feedback": feedback
118
+ # })
119
 
120
+ # history.append({
121
+ # "step": "生成",
122
+ # "status": "完成",
123
+ # "answer": response.content
124
+ # })
125
 
126
+ # return {"answer": response.content, "history": history}
127
 
128
+ # def verify(state: GraphState):
129
+ # """验证答案节点"""
130
+ # history = state["history"]
131
+ # history.append({"step": "验证", "status": "开始"})
132
 
133
+ # question = state["question"]
134
+ # answer = state["answer"]
135
+ # documents = state["documents"]
136
 
137
+ # # 验证提示词
138
+ # prompt = ChatPromptTemplate.from_messages([
139
+ # ("system", "评估答案是否符合以下标准:\n"
140
+ # "1. 是否基于提供的上下文\n"
141
+ # "2. 是否完整回答问题\n"
142
+ # "3. 是否包含幻觉信息\n\n"
143
+ # "返回JSON格式:{\"valid\": boolean, \"feedback\": string}"),
144
+ # ("human", "问题:{question}\n答案:{answer}\n上下文:{context}")
145
+ # ])
146
 
147
+ # chain = prompt | llm
148
+ # context = "\n\n".join(documents)
149
+ # result = chain.invoke({
150
+ # "question": question,
151
+ # "answer": answer,
152
+ # "context": context
153
+ # })
154
 
155
+ # try:
156
+ # # 尝试解析JSON输出
157
+ # verification = json.loads(result.content)
158
+ # except:
159
+ # # 如果解析失败,使用默认值
160
+ # verification = {"valid": False, "feedback": "验证失败: 无法解析验证结果"}
161
 
162
+ # history.append({
163
+ # "step": "验证",
164
+ # "status": "完成",
165
+ # "verification": verification
166
+ # })
167
 
168
+ # return {"verification": verification, "history": history}
169
 
170
+ # def should_retry(state: GraphState):
171
+ # """决定是否重试的条件函数"""
172
+ # history = state["history"]
173
 
174
+ # if state["verification"].get("valid", False):
175
+ # history.append({"step": "决策", "action": "验证通过,结束流程"})
176
+ # return "end"
177
+ # elif state["retries"] > 0:
178
+ # history.append({
179
+ # "step": "决策",
180
+ # "action": f"验证失败,剩余重试次数:{state['retries']},将重试"
181
+ # })
182
+ # return "retry"
183
+ # else:
184
+ # history.append({"step": "决策", "action": "重试次数用尽,结束流程"})
185
+ # return "end"
186
 
187
+ # def prepare_retry(state: GraphState):
188
+ # """准备重试节点"""
189
+ # history = state["history"]
190
+ # history.append({"step": "准备重试", "status": "开始"})
191
 
192
+ # feedback = state["verification"].get("feedback", "需要改进答案")
193
 
194
+ # history.append({
195
+ # "step": "准备重试",
196
+ # "status": "完成",
197
+ # "feedback": feedback
198
+ # })
199
 
200
+ # return {
201
+ # "feedback": feedback,
202
+ # "retries": state["retries"] - 1,
203
+ # "history": history
204
+ # }
205
 
206
+ # # 构建工作流
207
+ # workflow = StateGraph(GraphState)
208
 
209
+ # # 添加节点
210
+ # workflow.add_node("retrieve", retrieve)
211
+ # workflow.add_node("generate", generate)
212
+ # workflow.add_node("verify", verify)
213
+ # workflow.add_node("prepare_retry", prepare_retry)
214
 
215
+ # # 设置入口点
216
+ # workflow.set_entry_point("retrieve")
217
 
218
+ # # 添加边
219
+ # workflow.add_edge("retrieve", "generate")
220
+ # workflow.add_edge("generate", "verify")
221
+ # workflow.add_conditional_edges(
222
+ # "verify",
223
+ # should_retry,
224
+ # {
225
+ # "end": END,
226
+ # "retry": "prepare_retry"
227
+ # }
228
+ # )
229
+ # workflow.add_edge("prepare_retry", "retrieve")
230
 
231
+ # # 编译工作流
232
+ # app = workflow.compile()
233
 
234
+ # def run_agentic_rag(question: str, max_retries: int = 3):
235
+ # """运行Agentic RAG工作流"""
236
+ # initial_state = {
237
+ # "question": question,
238
+ # "documents": [],
239
+ # "answer": "",
240
+ # "verification": None,
241
+ # "retries": max_retries,
242
+ # "feedback": "",
243
+ # "history": [{"step": "初始化", "status": f"开始处理问题: {question}"}]
244
+ # }
245
 
246
+ # # 执行工作流
247
+ # final_state = None
248
+ # for step in app.stream(initial_state):
249
+ # node, state = next(iter(step.items()))
250
+ # final_state = state
251
 
252
+ # return {
253
+ # "answer": final_state["answer"],
254
+ # "documents": final_state["documents"],
255
+ # "history": final_state["history"],
256
+ # "retries_used": max_retries - final_state["retries"]
257
+ # }