mattibuzzo13 commited on
Commit
b6f4cbf
·
verified ·
1 Parent(s): 59de913

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -21
app.py CHANGED
@@ -7,10 +7,11 @@ import re
7
  import math
8
  import json
9
  import unicodedata
10
- from typing import TypedDict, Annotated
11
 
12
- from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
13
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 
14
  from langchain_core.tools import tool
15
  from langchain_community.tools import DuckDuckGoSearchRun
16
  from langchain_community.utilities import WikipediaAPIWrapper
@@ -89,10 +90,11 @@ def get_task_file(task_id: str) -> str:
89
  ct = response.headers.get("Content-Type", "")
90
  if "text" in ct or "json" in ct:
91
  return response.text[:5000]
92
- return f"[Binary file attached - content-type: {ct}]"
93
  return f"No file found for task {task_id}"
94
  except Exception as e:
95
  return f"Error fetching task file: {e}"
 
96
 
97
 
98
  class AgentState(TypedDict):
@@ -117,14 +119,24 @@ Always end with:
117
  FINAL ANSWER: <your answer here>
118
  """
119
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # --- Basic Agent Definition ---
122
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
123
  class BasicAgent:
124
  def __init__(self):
125
- print("Initializing LangGraph Agent (HF course pattern)...")
126
 
127
- tools = [
128
  web_search,
129
  wikipedia_search,
130
  python_repl,
@@ -132,22 +144,22 @@ class BasicAgent:
132
  get_task_file,
133
  ]
134
 
135
- # Pattern esatto dal notebook del corso:
136
- # llm.bind_tools() + StateGraph costruito manualmente
137
- endpoint = HuggingFaceEndpoint(
138
- repo_id="Qwen/Qwen2.5-72B-Instruct", # tool calling nativo
139
- huggingfacehub_api_token=os.getenv("HF_TOKEN"),
140
- task="conversational",
141
- max_new_tokens=1024,
142
- temperature=0.1,
143
  )
144
- llm = ChatHuggingFace(llm=endpoint, verbose=False)
145
- self.llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)
146
 
147
- # Grafo ReAct: assistant tools assistant (loop)
 
 
 
148
  builder = StateGraph(AgentState)
149
  builder.add_node("assistant", self._assistant_node)
150
- builder.add_node("tools", ToolNode(tools))
151
  builder.add_edge(START, "assistant")
152
  builder.add_conditional_edges("assistant", tools_condition)
153
  builder.add_edge("tools", "assistant")
@@ -155,11 +167,69 @@ class BasicAgent:
155
 
156
  print("Agent ready.")
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def _assistant_node(self, state: AgentState):
159
- """Chiama il LLM con system prompt + history dei messaggi."""
160
  sys_msg = SystemMessage(content=SYSTEM_PROMPT)
161
- response = self.llm_with_tools.invoke([sys_msg] + state["messages"])
162
- return {"messages": [response]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  def __call__(self, question: str) -> str:
165
  print(f"Agent received question (first 50 chars): {question[:50]}...")
 
7
  import math
8
  import json
9
  import unicodedata
10
+ from typing import TypedDict, Annotated, Any, List, Optional
11
 
12
+ from huggingface_hub import InferenceClient
13
+
14
+ from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage, ToolMessage
15
  from langchain_core.tools import tool
16
  from langchain_community.tools import DuckDuckGoSearchRun
17
  from langchain_community.utilities import WikipediaAPIWrapper
 
90
  ct = response.headers.get("Content-Type", "")
91
  if "text" in ct or "json" in ct:
92
  return response.text[:5000]
93
+ return f"[Binary file - content-type: {ct}]"
94
  return f"No file found for task {task_id}"
95
  except Exception as e:
96
  return f"Error fetching task file: {e}"
97
+
98
 
99
 
100
  class AgentState(TypedDict):
 
119
  FINAL ANSWER: <your answer here>
120
  """
121
 
122
+ def _tool_to_openai_schema(t) -> dict:
123
+ """Converte un LangChain tool nel formato tool OpenAI."""
124
+ return {
125
+ "type": "function",
126
+ "function": {
127
+ "name": t.name,
128
+ "description": t.description,
129
+ "parameters": t.args_schema.schema() if t.args_schema else {"type": "object", "properties": {}},
130
+ }
131
+ }
132
 
133
  # --- Basic Agent Definition ---
134
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
135
  class BasicAgent:
136
  def __init__(self):
137
+ print("Initializing agent with HF InferenceClient...")
138
 
139
+ self.tools_list = [
140
  web_search,
141
  wikipedia_search,
142
  python_repl,
 
144
  get_task_file,
145
  ]
146
 
147
+ # Mappa nome funzione tool per esecuzione
148
+ self.tools_by_name = {t.name: t for t in self.tools_list}
149
+
150
+ # InferenceClient diretto — usa la Serverless Inference API HF
151
+ self.client = InferenceClient(
152
+ model="Qwen/Qwen2.5-72B-Instruct",
153
+ token=os.getenv("HF_TOKEN"),
 
154
  )
 
 
155
 
156
+ # Schema OpenAI dei tool per passarli al client
157
+ self.tools_schema = [_tool_to_openai_schema(t) for t in self.tools_list]
158
+
159
+ # Grafo LangGraph per gestire il loop ReAct
160
  builder = StateGraph(AgentState)
161
  builder.add_node("assistant", self._assistant_node)
162
+ builder.add_node("tools", ToolNode(self.tools_list))
163
  builder.add_edge(START, "assistant")
164
  builder.add_conditional_edges("assistant", tools_condition)
165
  builder.add_edge("tools", "assistant")
 
167
 
168
  print("Agent ready.")
169
 
170
+ def _messages_to_hf_format(self, messages: list) -> list:
171
+ """Converte messaggi LangChain nel formato dict che InferenceClient si aspetta."""
172
+ result = []
173
+ for m in messages:
174
+ if isinstance(m, SystemMessage):
175
+ result.append({"role": "system", "content": m.content})
176
+ elif isinstance(m, HumanMessage):
177
+ result.append({"role": "user", "content": m.content})
178
+ elif isinstance(m, AIMessage):
179
+ msg = {"role": "assistant", "content": m.content or ""}
180
+ # Includi tool_calls se presenti
181
+ if m.tool_calls:
182
+ msg["tool_calls"] = [
183
+ {
184
+ "id": tc["id"],
185
+ "type": "function",
186
+ "function": {
187
+ "name": tc["name"],
188
+ "arguments": json.dumps(tc["args"]),
189
+ }
190
+ }
191
+ for tc in m.tool_calls
192
+ ]
193
+ result.append(msg)
194
+ elif isinstance(m, ToolMessage):
195
+ result.append({
196
+ "role": "tool",
197
+ "tool_call_id": m.tool_call_id,
198
+ "content": m.content,
199
+ })
200
+ return result
201
+
202
  def _assistant_node(self, state: AgentState):
203
+ """Nodo assistant: chiama InferenceClient con i tool e restituisce la risposta."""
204
  sys_msg = SystemMessage(content=SYSTEM_PROMPT)
205
+ hf_messages = self._messages_to_hf_format([sys_msg] + state["messages"])
206
+
207
+ response = self.client.chat_completion(
208
+ messages=hf_messages,
209
+ tools=self.tools_schema,
210
+ tool_choice="auto",
211
+ max_tokens=1024,
212
+ temperature=0.1,
213
+ )
214
+
215
+ choice = response.choices[0].message
216
+
217
+ # Costruisci AIMessage compatibile con LangGraph
218
+ tool_calls = []
219
+ if choice.tool_calls:
220
+ for tc in choice.tool_calls:
221
+ tool_calls.append({
222
+ "id": tc.id,
223
+ "name": tc.function.name,
224
+ "args": json.loads(tc.function.arguments),
225
+ "type": "tool_call",
226
+ })
227
+
228
+ ai_message = AIMessage(
229
+ content=choice.content or "",
230
+ tool_calls=tool_calls,
231
+ )
232
+ return {"messages": [ai_message]}
233
 
234
  def __call__(self, question: str) -> str:
235
  print(f"Agent received question (first 50 chars): {question[:50]}...")