RichardHu committed on
Commit
7606a74
·
verified ·
1 Parent(s): 3e36b8e

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +94 -248
tools.py CHANGED
@@ -1,257 +1,103 @@
1
- from smolagents import DuckDuckGoSearchTool
2
- from smolagents import Tool
3
- import random
4
- from huggingface_hub import list_models
5
-
6
-
7
- # Initialize the DuckDuckGo search tool
8
- #search_tool = DuckDuckGoSearchTool()
9
-
10
-
11
- class WeatherInfoTool(Tool):
12
- name = "weather_info"
13
- description = "Fetches dummy weather information for a given location."
14
- inputs = {
15
- "location": {
16
- "type": "string",
17
- "description": "The location to get weather information for."
18
- }
19
- }
20
- output_type = "string"
21
-
22
- def forward(self, location: str):
23
- # Dummy weather data
24
- weather_conditions = [
25
- {"condition": "Rainy", "temp_c": 15},
26
- {"condition": "Clear", "temp_c": 25},
27
- {"condition": "Windy", "temp_c": 20}
28
- ]
29
- # Randomly select a weather condition
30
- data = random.choice(weather_conditions)
31
- return f"Weather in {location}: {data['condition']}, {data['temp_c']}°C"
32
-
33
- class HubStatsTool(Tool):
34
- name = "hub_stats"
35
- description = "Fetches the most downloaded model from a specific author on the Hugging Face Hub."
36
- inputs = {
37
- "author": {
38
- "type": "string",
39
- "description": "The username of the model author/organization to find models from."
40
- }
41
- }
42
- output_type = "string"
43
-
44
- def forward(self, author: str):
45
- try:
46
- # List models from the specified author, sorted by downloads
47
- models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
48
-
49
- if models:
50
- model = models[0]
51
- return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
52
- else:
53
- return f"No models found for author {author}."
54
- except Exception as e:
55
- return f"Error fetching models for {author}: {str(e)}"
56
-
57
-
58
- # from typing import TypedDict, List, Optional, Annotated
59
- # from langchain_core.messages import BaseMessage
60
- # from langgraph.graph import StateGraph, END
61
- # from langchain_core.prompts import ChatPromptTemplate
62
- # from langchain_openai import ChatOpenAI
63
- # from retriever import get_retriever
64
- # import json
65
 
66
- # # 定义状态对象
67
- # class GraphState(TypedDict):
68
- # question: str
69
- # documents: List[str]
70
- # answer: str
71
- # verification: Annotated[Optional[dict], "验证结果"]
72
- # retries: Annotated[int, "剩余重试次数"]
73
- # feedback: Annotated[Optional[str], "前次验证的反馈"]
74
- # history: Annotated[List[dict], "执行历史记录"]
75
 
76
- # # 初始化检索器和模型
77
- # retriever = get_retriever()
78
- # llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
79
 
80
- # def retrieve(state: GraphState):
81
- # """检索文档节点"""
82
- # history = state["history"]
83
- # history.append({"step": "检索", "status": "开始"})
84
-
85
- # question = state["question"]
86
- # documents = retriever.get_relevant_documents(question)
87
- # doc_contents = [doc.page_content for doc in documents]
88
-
89
- # history.append({
90
- # "step": "检索",
91
- # "status": "完成",
92
- # "documents": doc_contents
93
- # })
94
-
95
- # return {"documents": doc_contents, "history": history}
96
 
97
- # def generate(state: GraphState):
98
- # """生成答案节点"""
99
- # history = state["history"]
100
- # history.append({"step": "生成", "status": "开始"})
101
-
102
- # question = state["question"]
103
- # documents = state["documents"]
104
- # feedback = state.get("feedback", "")
105
-
106
- # # 构建提示词
107
- # prompt = ChatPromptTemplate.from_messages([
108
- # ("system", "你是一个专业助手,基于以下上下文回答问题。如果上下文不足,请说明。{feedback}"),
109
- # ("human", "问题:{question}\n上下文:{context}")
110
- # ])
111
-
112
- # chain = prompt | llm
113
- # context = "\n\n".join(documents)
114
- # response = chain.invoke({
115
- # "question": question,
116
- # "context": context,
117
- # "feedback": feedback
118
- # })
119
-
120
- # history.append({
121
- # "step": "生成",
122
- # "status": "完成",
123
- # "answer": response.content
124
- # })
125
-
126
- # return {"answer": response.content, "history": history}
127
-
128
- # def verify(state: GraphState):
129
- # """验证答案节点"""
130
- # history = state["history"]
131
- # history.append({"step": "验证", "status": "开始"})
132
-
133
- # question = state["question"]
134
- # answer = state["answer"]
135
- # documents = state["documents"]
136
-
137
- # # 验证提示词
138
- # prompt = ChatPromptTemplate.from_messages([
139
- # ("system", "评估答案是否符合以下标准:\n"
140
- # "1. 是否基于提供的上下文\n"
141
- # "2. 是否完整回答问题\n"
142
- # "3. 是否包含幻觉信息\n\n"
143
- # "返回JSON格式:{\"valid\": boolean, \"feedback\": string}"),
144
- # ("human", "问题:{question}\n答案:{answer}\n上下文:{context}")
145
- # ])
146
-
147
- # chain = prompt | llm
148
- # context = "\n\n".join(documents)
149
- # result = chain.invoke({
150
- # "question": question,
151
- # "answer": answer,
152
- # "context": context
153
- # })
154
-
155
- # try:
156
- # # 尝试解析JSON输出
157
- # verification = json.loads(result.content)
158
- # except:
159
- # # 如果解析失败,使用默认值
160
- # verification = {"valid": False, "feedback": "验证失败: 无法解析验证结果"}
161
-
162
- # history.append({
163
- # "step": "验证",
164
- # "status": "完成",
165
- # "verification": verification
166
- # })
167
-
168
- # return {"verification": verification, "history": history}
169
-
170
- # def should_retry(state: GraphState):
171
- # """决定是否重试的条件函数"""
172
- # history = state["history"]
173
-
174
- # if state["verification"].get("valid", False):
175
- # history.append({"step": "决策", "action": "验证通过,结束流程"})
176
- # return "end"
177
- # elif state["retries"] > 0:
178
- # history.append({
179
- # "step": "决策",
180
- # "action": f"验证失败,剩余重试次数:{state['retries']},将重试"
181
- # })
182
- # return "retry"
183
- # else:
184
- # history.append({"step": "决策", "action": "重试次数用尽,结束流程"})
185
- # return "end"
186
-
187
- # def prepare_retry(state: GraphState):
188
- # """准备重试节点"""
189
- # history = state["history"]
190
- # history.append({"step": "准备重试", "status": "开始"})
191
-
192
- # feedback = state["verification"].get("feedback", "需要改进答案")
193
-
194
- # history.append({
195
- # "step": "准备重试",
196
- # "status": "完成",
197
- # "feedback": feedback
198
- # })
199
-
200
- # return {
201
- # "feedback": feedback,
202
- # "retries": state["retries"] - 1,
203
- # "history": history
204
  # }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- # # 构建工作流
207
- # workflow = StateGraph(GraphState)
208
-
209
- # # 添加节点
210
- # workflow.add_node("retrieve", retrieve)
211
- # workflow.add_node("generate", generate)
212
- # workflow.add_node("verify", verify)
213
- # workflow.add_node("prepare_retry", prepare_retry)
214
-
215
- # # 设置入口点
216
- # workflow.set_entry_point("retrieve")
 
217
 
218
- # # 添加边
219
- # workflow.add_edge("retrieve", "generate")
220
- # workflow.add_edge("generate", "verify")
221
- # workflow.add_conditional_edges(
222
- # "verify",
223
- # should_retry,
224
- # {
225
- # "end": END,
226
- # "retry": "prepare_retry"
227
- # }
228
- # )
229
- # workflow.add_edge("prepare_retry", "retrieve")
230
 
231
- # # 编译工作流
232
- # app = workflow.compile()
 
233
 
234
- # def run_agentic_rag(question: str, max_retries: int = 3):
235
- # """运行Agentic RAG工作流"""
236
- # initial_state = {
237
- # "question": question,
238
- # "documents": [],
239
- # "answer": "",
240
- # "verification": None,
241
- # "retries": max_retries,
242
- # "feedback": "",
243
- # "history": [{"step": "初始化", "status": f"开始处理问题: {question}"}]
244
- # }
245
-
246
- # # 执行工作流
247
- # final_state = None
248
- # for step in app.stream(initial_state):
249
- # node, state = next(iter(step.items()))
250
- # final_state = state
251
-
252
- # return {
253
- # "answer": final_state["answer"],
254
- # "documents": final_state["documents"],
255
- # "history": final_state["history"],
256
- # "retries_used": max_retries - final_state["retries"]
257
- # }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from smolagents import DuckDuckGoSearchTool
2
+ # from smolagents import Tool
3
+ # import random
4
+ # from huggingface_hub import list_models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
6
 
7
+ # # Initialize the DuckDuckGo search tool
8
+ # #search_tool = DuckDuckGoSearchTool()
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # class WeatherInfoTool(Tool):
12
+ # name = "weather_info"
13
+ # description = "Fetches dummy weather information for a given location."
14
+ # inputs = {
15
+ # "location": {
16
+ # "type": "string",
17
+ # "description": "The location to get weather information for."
18
+ # }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # }
20
+ # output_type = "string"
21
+
22
+ # def forward(self, location: str):
23
+ # # Dummy weather data
24
+ # weather_conditions = [
25
+ # {"condition": "Rainy", "temp_c": 15},
26
+ # {"condition": "Clear", "temp_c": 25},
27
+ # {"condition": "Windy", "temp_c": 20}
28
+ # ]
29
+ # # Randomly select a weather condition
30
+ # data = random.choice(weather_conditions)
31
+ # return f"Weather in {location}: {data['condition']}, {data['temp_c']}°C"
32
+
33
+ # class HubStatsTool(Tool):
34
+ # name = "hub_stats"
35
+ # description = "Fetches the most downloaded model from a specific author on the Hugging Face Hub."
36
+ # inputs = {
37
+ # "author": {
38
+ # "type": "string",
39
+ # "description": "The username of the model author/organization to find models from."
40
+ # }
41
+ # }
42
+ # output_type = "string"
43
 
44
+ # def forward(self, author: str):
45
+ # try:
46
+ # # List models from the specified author, sorted by downloads
47
+ # models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
48
+
49
+ # if models:
50
+ # model = models[0]
51
+ # return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
52
+ # else:
53
+ # return f"No models found for author {author}."
54
+ # except Exception as e:
55
+ # return f"Error fetching models for {author}: {str(e)}"
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ from langchain.tools import Tool
59
+ from huggingface_hub import list_models
60
+ import random
61
 
62
def get_weather_info(location: str) -> str:
    """Return a randomly chosen dummy weather report for *location*.

    The data is hard-coded sample weather, not a real forecast: one of
    three condition/temperature pairs is picked at random on every call.
    """
    # Static sample data; a real implementation would query a weather API.
    sample_reports = (
        {"condition": "Rainy", "temp_c": 15},
        {"condition": "Clear", "temp_c": 25},
        {"condition": "Windy", "temp_c": 20},
    )
    picked = random.choice(sample_reports)
    return f"Weather in {location}: {picked['condition']}, {picked['temp_c']}°C"
73
+
74
# Wrap get_weather_info as a LangChain Tool so an agent can call it by name.
# NOTE(review): this uses the single-string-input `Tool(name, func, description)`
# form — confirm it matches the installed langchain version's API.
weather_info_tool = Tool(
    name="get_weather_info",
    func=get_weather_info,
    description="Fetches dummy weather information for a given location."
)
80
+
81
def get_hub_stats(author: str) -> str:
    """Return a one-line summary of *author*'s most-downloaded Hub model.

    Queries the Hugging Face Hub for the single most-downloaded model
    belonging to the given user/organization.  Any failure (network error,
    unknown author, ...) is reported as a human-readable string instead of
    being raised, so the result is always safe to hand back to an agent.
    """
    try:
        # Ask the Hub for at most one model, ordered by downloads descending.
        top_models = list(
            list_models(author=author, sort="downloads", direction=-1, limit=1)
        )
        if not top_models:
            return f"No models found for author {author}."
        top = top_models[0]
        return f"The most downloaded model by {author} is {top.id} with {top.downloads:,} downloads."
    except Exception as e:
        return f"Error fetching models for {author}: {str(e)}"
94
+
95
# Wrap get_hub_stats as a LangChain Tool so an agent can call it by name.
# NOTE(review): this uses the single-string-input `Tool(name, func, description)`
# form — confirm it matches the installed langchain version's API.
hub_stats_tool = Tool(
    name="get_hub_stats",
    func=get_hub_stats,
    description="Fetches the most downloaded model from a specific author on the Hugging Face Hub."
)
101
+
102
# Example usage, guarded so that merely importing this module does not hit
# the network: the original ran this print at import time, which made every
# `import tools` perform a Hugging Face Hub request (and fail offline).
if __name__ == "__main__":
    # NOTE(review): calling a LangChain Tool directly via __call__ is
    # deprecated in newer langchain releases — confirm the installed version
    # and prefer hub_stats_tool.run("facebook") / .invoke("facebook").
    print(hub_stats_tool("facebook"))  # Example: most downloaded model by Facebook