Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -152,31 +152,14 @@ SYSTEM_PROMPT = SystemMessage(
|
|
| 152 |
# ---------------------------------------------------------------------
|
| 153 |
# 7) LangGraph – Planner + Tools + Router
|
| 154 |
# ---------------------------------------------------------------------
|
| 155 |
-
def extract_tool_call(text: str) -> tuple[str, str] | None:
|
| 156 |
-
"""Parse Gemini output like: 'Tool: xyz\nInput: abc'."""
|
| 157 |
-
match = re.search(r"Tool:\s*(\w+)\s*Input:\s*(.+)", text, re.DOTALL)
|
| 158 |
-
if match:
|
| 159 |
-
return match.group(1).strip(), match.group(2).strip()
|
| 160 |
-
return None
|
| 161 |
|
| 162 |
def planner(state: MessagesState):
|
| 163 |
msgs = state["messages"]
|
| 164 |
if msgs[0].type != "system":
|
| 165 |
msgs = [SYSTEM_PROMPT] + msgs
|
| 166 |
-
|
| 167 |
resp = with_backoff(lambda: gemini_llm.invoke(msgs))
|
| 168 |
content = resp.content.strip()
|
| 169 |
-
|
| 170 |
-
parsed = extract_tool_call(content)
|
| 171 |
-
if parsed:
|
| 172 |
-
tool_name, tool_input = parsed
|
| 173 |
-
tool = {t.name: t for t in TOOLS}.get(tool_name)
|
| 174 |
-
if tool:
|
| 175 |
-
result = tool.invoke(tool_input)
|
| 176 |
-
new_msg = HumanMessage(content=f"Tool result:\n{result}")
|
| 177 |
-
return {"messages": msgs + [resp, new_msg], "should_end": False}
|
| 178 |
-
|
| 179 |
-
finished = "\n" not in content # einfache Heuristik
|
| 180 |
return {"messages": msgs + [resp], "should_end": finished}
|
| 181 |
|
| 182 |
def route(state):
|
|
|
|
| 152 |
# ---------------------------------------------------------------------
|
| 153 |
# 7) LangGraph – Planner + Tools + Router
|
| 154 |
# ---------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
def planner(state: MessagesState):
|
| 157 |
msgs = state["messages"]
|
| 158 |
if msgs[0].type != "system":
|
| 159 |
msgs = [SYSTEM_PROMPT] + msgs
|
|
|
|
| 160 |
resp = with_backoff(lambda: gemini_llm.invoke(msgs))
|
| 161 |
content = resp.content.strip()
|
| 162 |
+
finished = not getattr(resp, "tool_calls", None) and "\n" not in content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
return {"messages": msgs + [resp], "should_end": finished}
|
| 164 |
|
| 165 |
def route(state):
|