Spaces:
Sleeping
Sleeping
Vela
commited on
Commit
·
540db73
1
Parent(s):
75115cd
modified functions
Browse files- application/agents/scraper_agent.py +8 -32
- application/services/gemini_api_service.py +1 -7
- main.py +12 -12
application/agents/scraper_agent.py
CHANGED
|
@@ -42,57 +42,42 @@ model_with_tools = model.bind_tools(tools)
|
|
| 42 |
def invoke_model(state: AgentState) -> dict:
|
| 43 |
"""Invokes the LLM with the current conversation history."""
|
| 44 |
logger.info("--- Invoking Model ---")
|
| 45 |
-
# LangGraph automatically passes the entire state
|
| 46 |
-
# The model_with_tools expects a list of BaseMessages
|
| 47 |
response = model_with_tools.invoke(state['messages'])
|
| 48 |
-
|
| 49 |
-
# We return a dictionary with the key corresponding to the state field name
|
| 50 |
-
return {"messages": [response]} # The response is already an AIMessage
|
| 51 |
|
| 52 |
def invoke_tools(state: AgentState) -> dict:
|
| 53 |
"""Invokes the necessary tools based on the last AI message."""
|
| 54 |
logger.info("--- Invoking Tools ---")
|
| 55 |
-
# The state contains the history, the last message is the AI's request
|
| 56 |
last_message = state['messages'][-1]
|
| 57 |
|
| 58 |
-
# Check if the last message is an AIMessage with tool_calls
|
| 59 |
if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
|
| 60 |
logger.info("No tool calls found in the last message.")
|
| 61 |
-
# This scenario might indicate the conversation should end or requires clarification
|
| 62 |
-
# For now, return an empty dict, which won't update the state significantly.
|
| 63 |
-
# Consider adding a message indicating no tools were called if needed.
|
| 64 |
return {}
|
| 65 |
-
# Alternative: return {"messages": [SystemMessage(content="No tool calls requested.")]}
|
| 66 |
|
| 67 |
tool_invocation_messages = []
|
| 68 |
|
| 69 |
-
# Find the tool object by name
|
| 70 |
tool_map = {tool.name: tool for tool in tools}
|
| 71 |
|
| 72 |
for tool_call in last_message.tool_calls:
|
| 73 |
tool_name = tool_call['name']
|
| 74 |
tool_args = tool_call['args']
|
| 75 |
-
tool_call_id = tool_call['id']
|
| 76 |
|
| 77 |
logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
|
| 78 |
|
| 79 |
if tool_name in tool_map:
|
| 80 |
selected_tool = tool_map[tool_name]
|
| 81 |
try:
|
| 82 |
-
# Use the tool's invoke method, passing the arguments dictionary
|
| 83 |
result = selected_tool.invoke(tool_args)
|
| 84 |
|
| 85 |
-
# IMPORTANT: Convert the result to a string or a JSON serializable format
|
| 86 |
-
# if it's a complex object. ToolMessage content should be simple.
|
| 87 |
-
# Adjust this based on what your tools actually return.
|
| 88 |
if isinstance(result, list) or isinstance(result, dict):
|
| 89 |
-
result_content = json.dumps(result)
|
| 90 |
-
elif hasattr(result, 'companies') and isinstance(result.companies, list):
|
| 91 |
result_content = f"Companies found: {', '.join(result.companies)}"
|
| 92 |
elif result is None:
|
| 93 |
result_content = "Tool executed successfully, but returned no specific data (None)."
|
| 94 |
else:
|
| 95 |
-
result_content = str(result)
|
| 96 |
|
| 97 |
logger.info(f"Tool {tool_name} result: {result_content}")
|
| 98 |
tool_invocation_messages.append(
|
|
@@ -100,7 +85,6 @@ def invoke_tools(state: AgentState) -> dict:
|
|
| 100 |
)
|
| 101 |
except Exception as e:
|
| 102 |
logger.error(f"Error executing tool {tool_name}: {e}")
|
| 103 |
-
# Return an error message in the ToolMessage
|
| 104 |
tool_invocation_messages.append(
|
| 105 |
ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
|
| 106 |
)
|
|
@@ -110,29 +94,22 @@ def invoke_tools(state: AgentState) -> dict:
|
|
| 110 |
ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
|
| 111 |
)
|
| 112 |
|
| 113 |
-
# Return the collected ToolMessages to be added to the state
|
| 114 |
return {"messages": tool_invocation_messages}
|
| 115 |
|
| 116 |
-
# --- Graph Definition ---
|
| 117 |
graph_builder = StateGraph(AgentState)
|
| 118 |
|
| 119 |
-
# Add nodes
|
| 120 |
graph_builder.add_node("scraper_agent", invoke_model)
|
| 121 |
-
graph_builder.add_node("tools", invoke_tools)
|
| 122 |
|
| 123 |
-
# Define edges
|
| 124 |
graph_builder.set_entry_point("scraper_agent")
|
| 125 |
|
| 126 |
-
# Conditional edge: After the agent runs, decide whether to call tools or end.
|
| 127 |
def router(state: AgentState) -> str:
|
| 128 |
"""Determines the next step based on the last message."""
|
| 129 |
last_message = state['messages'][-1]
|
| 130 |
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
|
| 131 |
-
# If the AI message has tool calls, invoke the tools node
|
| 132 |
logger.info("--- Routing to Tools ---")
|
| 133 |
return "tools"
|
| 134 |
else:
|
| 135 |
-
# Otherwise, the conversation can end
|
| 136 |
logger.info("--- Routing to End ---")
|
| 137 |
return END
|
| 138 |
|
|
@@ -140,12 +117,11 @@ graph_builder.add_conditional_edges(
|
|
| 140 |
"scraper_agent",
|
| 141 |
router,
|
| 142 |
{
|
| 143 |
-
"tools": "tools",
|
| 144 |
-
END: END,
|
| 145 |
}
|
| 146 |
)
|
| 147 |
|
| 148 |
-
# After tools are invoked, their results (ToolMessages) should go back to the agent
|
| 149 |
graph_builder.add_edge("tools", "scraper_agent")
|
| 150 |
|
| 151 |
# Compile the graph
|
|
|
|
| 42 |
def invoke_model(state: AgentState) -> dict:
|
| 43 |
"""Invokes the LLM with the current conversation history."""
|
| 44 |
logger.info("--- Invoking Model ---")
|
|
|
|
|
|
|
| 45 |
response = model_with_tools.invoke(state['messages'])
|
| 46 |
+
return {"messages": [response]}
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def invoke_tools(state: AgentState) -> dict:
|
| 49 |
"""Invokes the necessary tools based on the last AI message."""
|
| 50 |
logger.info("--- Invoking Tools ---")
|
|
|
|
| 51 |
last_message = state['messages'][-1]
|
| 52 |
|
|
|
|
| 53 |
if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
|
| 54 |
logger.info("No tool calls found in the last message.")
|
|
|
|
|
|
|
|
|
|
| 55 |
return {}
|
|
|
|
| 56 |
|
| 57 |
tool_invocation_messages = []
|
| 58 |
|
|
|
|
| 59 |
tool_map = {tool.name: tool for tool in tools}
|
| 60 |
|
| 61 |
for tool_call in last_message.tool_calls:
|
| 62 |
tool_name = tool_call['name']
|
| 63 |
tool_args = tool_call['args']
|
| 64 |
+
tool_call_id = tool_call['id']
|
| 65 |
|
| 66 |
logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
|
| 67 |
|
| 68 |
if tool_name in tool_map:
|
| 69 |
selected_tool = tool_map[tool_name]
|
| 70 |
try:
|
|
|
|
| 71 |
result = selected_tool.invoke(tool_args)
|
| 72 |
|
|
|
|
|
|
|
|
|
|
| 73 |
if isinstance(result, list) or isinstance(result, dict):
|
| 74 |
+
result_content = json.dumps(result)
|
| 75 |
+
elif hasattr(result, 'companies') and isinstance(result.companies, list):
|
| 76 |
result_content = f"Companies found: {', '.join(result.companies)}"
|
| 77 |
elif result is None:
|
| 78 |
result_content = "Tool executed successfully, but returned no specific data (None)."
|
| 79 |
else:
|
| 80 |
+
result_content = str(result)
|
| 81 |
|
| 82 |
logger.info(f"Tool {tool_name} result: {result_content}")
|
| 83 |
tool_invocation_messages.append(
|
|
|
|
| 85 |
)
|
| 86 |
except Exception as e:
|
| 87 |
logger.error(f"Error executing tool {tool_name}: {e}")
|
|
|
|
| 88 |
tool_invocation_messages.append(
|
| 89 |
ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
|
| 90 |
)
|
|
|
|
| 94 |
ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
|
| 95 |
)
|
| 96 |
|
|
|
|
| 97 |
return {"messages": tool_invocation_messages}
|
| 98 |
|
|
|
|
| 99 |
graph_builder = StateGraph(AgentState)
|
| 100 |
|
|
|
|
| 101 |
graph_builder.add_node("scraper_agent", invoke_model)
|
| 102 |
+
graph_builder.add_node("tools", invoke_tools)
|
| 103 |
|
|
|
|
| 104 |
graph_builder.set_entry_point("scraper_agent")
|
| 105 |
|
|
|
|
| 106 |
def router(state: AgentState) -> str:
|
| 107 |
"""Determines the next step based on the last message."""
|
| 108 |
last_message = state['messages'][-1]
|
| 109 |
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
|
|
|
|
| 110 |
logger.info("--- Routing to Tools ---")
|
| 111 |
return "tools"
|
| 112 |
else:
|
|
|
|
| 113 |
logger.info("--- Routing to End ---")
|
| 114 |
return END
|
| 115 |
|
|
|
|
| 117 |
"scraper_agent",
|
| 118 |
router,
|
| 119 |
{
|
| 120 |
+
"tools": "tools",
|
| 121 |
+
END: END,
|
| 122 |
}
|
| 123 |
)
|
| 124 |
|
|
|
|
| 125 |
graph_builder.add_edge("tools", "scraper_agent")
|
| 126 |
|
| 127 |
# Compile the graph
|
application/services/gemini_api_service.py
CHANGED
|
@@ -152,13 +152,11 @@ def upload_file(
|
|
| 152 |
Exception: If upload fails.
|
| 153 |
"""
|
| 154 |
try:
|
| 155 |
-
# Determine if input is a URL
|
| 156 |
is_url = isinstance(file, str) and file.startswith(('http://', 'https://'))
|
| 157 |
|
| 158 |
-
# Determine file name if not provided
|
| 159 |
if not file_name:
|
| 160 |
if is_url:
|
| 161 |
-
file_name = os.path.basename(file.split("?")[0])
|
| 162 |
elif isinstance(file, str):
|
| 163 |
file_name = os.path.basename(file)
|
| 164 |
elif hasattr(file, "name"):
|
|
@@ -172,14 +170,12 @@ def upload_file(
|
|
| 172 |
config.update({"name": sanitized_name, "mime_type": mime_type})
|
| 173 |
gemini_file_key = f"files/{sanitized_name}"
|
| 174 |
|
| 175 |
-
# Check if file already exists
|
| 176 |
if gemini_file_key in get_files():
|
| 177 |
logger.info(f"File already exists on Gemini: {gemini_file_key}")
|
| 178 |
return client.files.get(name=gemini_file_key)
|
| 179 |
|
| 180 |
logger.info(f"Uploading file to Gemini: {gemini_file_key}")
|
| 181 |
|
| 182 |
-
# Handle URL
|
| 183 |
if is_url:
|
| 184 |
headers = {
|
| 185 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
|
@@ -189,14 +185,12 @@ def upload_file(
|
|
| 189 |
file_content = io.BytesIO(response.content)
|
| 190 |
return client.files.upload(file=file_content, config=config)
|
| 191 |
|
| 192 |
-
# Handle local file path
|
| 193 |
if isinstance(file, str):
|
| 194 |
if not os.path.isfile(file):
|
| 195 |
raise FileNotFoundError(f"Local file '{file}' does not exist.")
|
| 196 |
with open(file, "rb") as f:
|
| 197 |
return client.files.upload(file=f, config=config)
|
| 198 |
|
| 199 |
-
# Handle already opened binary file object
|
| 200 |
return client.files.upload(file=file, config=config)
|
| 201 |
|
| 202 |
except Exception as e:
|
|
|
|
| 152 |
Exception: If upload fails.
|
| 153 |
"""
|
| 154 |
try:
|
|
|
|
| 155 |
is_url = isinstance(file, str) and file.startswith(('http://', 'https://'))
|
| 156 |
|
|
|
|
| 157 |
if not file_name:
|
| 158 |
if is_url:
|
| 159 |
+
file_name = os.path.basename(file.split("?")[0])
|
| 160 |
elif isinstance(file, str):
|
| 161 |
file_name = os.path.basename(file)
|
| 162 |
elif hasattr(file, "name"):
|
|
|
|
| 170 |
config.update({"name": sanitized_name, "mime_type": mime_type})
|
| 171 |
gemini_file_key = f"files/{sanitized_name}"
|
| 172 |
|
|
|
|
| 173 |
if gemini_file_key in get_files():
|
| 174 |
logger.info(f"File already exists on Gemini: {gemini_file_key}")
|
| 175 |
return client.files.get(name=gemini_file_key)
|
| 176 |
|
| 177 |
logger.info(f"Uploading file to Gemini: {gemini_file_key}")
|
| 178 |
|
|
|
|
| 179 |
if is_url:
|
| 180 |
headers = {
|
| 181 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
|
|
|
| 185 |
file_content = io.BytesIO(response.content)
|
| 186 |
return client.files.upload(file=file_content, config=config)
|
| 187 |
|
|
|
|
| 188 |
if isinstance(file, str):
|
| 189 |
if not os.path.isfile(file):
|
| 190 |
raise FileNotFoundError(f"Local file '{file}' does not exist.")
|
| 191 |
with open(file, "rb") as f:
|
| 192 |
return client.files.upload(file=f, config=config)
|
| 193 |
|
|
|
|
| 194 |
return client.files.upload(file=file, config=config)
|
| 195 |
|
| 196 |
except Exception as e:
|
main.py
CHANGED
|
@@ -147,15 +147,15 @@ workflow.set_entry_point("supervisor")
|
|
| 147 |
graph = workflow.compile()
|
| 148 |
|
| 149 |
# # === Example Run ===
|
| 150 |
-
if __name__ == "__main__":
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
| 147 |
graph = workflow.compile()
|
| 148 |
|
| 149 |
# # === Example Run ===
|
| 150 |
+
# if __name__ == "__main__":
|
| 151 |
+
# logger.info("Starting the graph execution...")
|
| 152 |
+
# initial_message = HumanMessage(content="Can you get zalando pdf link")
|
| 153 |
+
# input_state = {"messages": [initial_message]}
|
| 154 |
+
|
| 155 |
+
# for step in graph.stream(input_state):
|
| 156 |
+
# if "__end__" not in step:
|
| 157 |
+
# logger.info(f"Graph Step Output: {step}")
|
| 158 |
+
# print(step)
|
| 159 |
+
# print("----")
|
| 160 |
+
|
| 161 |
+
# logger.info("Graph execution completed.")
|