Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -20,44 +20,44 @@ from prompt import system_prompt
|
|
| 20 |
# --------------------------------------------------------------------
|
| 21 |
# 1. API Key Rotation Setup
|
| 22 |
# --------------------------------------------------------------------
|
| 23 |
-
api_keys = [
|
| 24 |
os.getenv("OPENROUTER_API_KEY"),
|
| 25 |
os.getenv("OPENROUTER_API_KEY_1")
|
| 26 |
-
]
|
| 27 |
|
| 28 |
-
if not
|
| 29 |
raise EnvironmentError("No OpenRouter API keys found in environment variables.")
|
| 30 |
|
| 31 |
-
api_key_cycle = itertools.cycle(
|
| 32 |
|
| 33 |
-
def get_next_api_key():
|
| 34 |
"""Get the next API key in rotation."""
|
| 35 |
return next(api_key_cycle)
|
| 36 |
|
| 37 |
class RotatingChatOpenAI(ChatOpenAI):
|
| 38 |
-
"""ChatOpenAI wrapper that automatically rotates API keys on failure."""
|
| 39 |
|
| 40 |
def invoke(self, *args, **kwargs):
|
| 41 |
-
#
|
| 42 |
-
|
| 43 |
-
self.
|
| 44 |
try:
|
| 45 |
return super().invoke(*args, **kwargs)
|
| 46 |
except Exception as e:
|
| 47 |
# Handle rate-limits or auth errors
|
| 48 |
if any(code in str(e) for code in ["429", "401", "403"]):
|
| 49 |
-
print(f"[API Key Rotation] Key {
|
| 50 |
continue
|
| 51 |
-
raise # Re-raise
|
| 52 |
-
raise RuntimeError("All OpenRouter API keys failed or rate-limited.")
|
| 53 |
|
| 54 |
# --------------------------------------------------------------------
|
| 55 |
# 2. Initialize LLM with API Key Rotation
|
| 56 |
# --------------------------------------------------------------------
|
| 57 |
llm = RotatingChatOpenAI(
|
| 58 |
base_url="https://openrouter.ai/api/v1",
|
| 59 |
-
|
| 60 |
-
model="qwen/qwen3-coder:free",
|
| 61 |
temperature=1
|
| 62 |
)
|
| 63 |
|
|
@@ -99,18 +99,31 @@ def llm_call(state: MessagesState):
|
|
| 99 |
# Tool Node
|
| 100 |
def tool_node(state: MessagesState):
|
| 101 |
"""Executes tools requested by the LLM."""
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
# Conditional Routing
|
| 110 |
def should_continue(state: MessagesState) -> Literal["Action", END]:
|
| 111 |
"""Route to tools if LLM made a tool call, else end."""
|
| 112 |
last_message = state["messages"][-1]
|
| 113 |
-
return "Action" if last_message
|
| 114 |
|
| 115 |
# --------------------------------------------------------------------
|
| 116 |
# 5. Build LangGraph Agent
|
|
@@ -154,8 +167,8 @@ class LangGraphAgent:
|
|
| 154 |
result = gaia_agent.invoke(input_state, config)
|
| 155 |
final_response = result["messages"][-1].content
|
| 156 |
|
| 157 |
-
|
|
|
|
| 158 |
return final_response.split("FINAL ANSWER:")[-1].strip()
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
return final_response
|
|
|
|
# --------------------------------------------------------------------
# 1. API Key Rotation Setup
# --------------------------------------------------------------------
# Collect every configured OpenRouter key, dropping unset/empty entries.
_candidate_keys = (
    os.getenv("OPENROUTER_API_KEY"),
    os.getenv("OPENROUTER_API_KEY_1"),
)
api_keys = [key for key in _candidate_keys if key]

if not api_keys:
    raise EnvironmentError("No OpenRouter API keys found in environment variables.")

# Round-robin iterator over the available keys; never exhausted.
api_key_cycle = itertools.cycle(api_keys)


def get_next_api_key() -> str:
    """Get the next API key in rotation."""
    return next(api_key_cycle)
class RotatingChatOpenAI(ChatOpenAI):
    """ChatOpenAI wrapper that automatically rotates OpenRouter API keys on failure."""

    def invoke(self, *args, **kwargs):
        """Invoke the model, giving each configured key one attempt.

        Rotates to the next key on rate-limit/auth errors (429/401/403);
        any other exception is re-raised immediately.
        """
        for _ in range(len(api_keys)):  # try each key once per call
            current_key = get_next_api_key()
            # Swap the key on the client before (re)trying the request.
            self.openai_api_key = current_key
            try:
                return super().invoke(*args, **kwargs)
            except Exception as e:
                # Only rotate on rate-limit or auth failures (matched by
                # status code appearing in the error text).
                if not any(code in str(e) for code in ["429", "401", "403"]):
                    raise  # Re-raise unexpected errors
                print(f"[API Key Rotation] Key {current_key[:5]}... failed, trying next key...")
        raise RuntimeError("All OpenRouter API keys failed or were rate-limited.")
# --------------------------------------------------------------------
# 2. Initialize LLM with API Key Rotation
# --------------------------------------------------------------------
# NOTE: the chosen model must support tool/function calling; rotation
# starts from the first key yielded by the cycle.
llm = RotatingChatOpenAI(
    base_url="https://openrouter.ai/api/v1",
    openai_api_key=get_next_api_key(),
    model="qwen/qwen3-coder:free",
    temperature=1,
)
|
| 63 |
|
|
|
|
# Tool Node
def tool_node(state: MessagesState):
    """Executes tools requested by the LLM.

    Reads the tool calls attached to the last message, runs each matching
    tool, and returns the resulting ToolMessages to append to the state.
    Unknown tools and tool failures are reported back to the LLM as
    error-text observations instead of crashing the graph.
    """
    results = []
    last_message = state["messages"][-1]

    for tool_call in getattr(last_message, "tool_calls", []) or []:
        tool = tools_by_name.get(tool_call["name"])
        if not tool:
            results.append(ToolMessage(content=f"Unknown tool: {tool_call['name']}", tool_call_id=tool_call["id"]))
            continue

        try:
            # BaseTool.invoke takes the argument mapping as a single
            # positional input — splatting it as **kwargs raises TypeError.
            observation = tool.invoke(tool_call["args"])
        except Exception as e:
            observation = f"[Tool Error] {e}"

        # ToolMessage.content must be text; coerce non-string tool outputs.
        results.append(ToolMessage(content=str(observation), tool_call_id=tool_call["id"]))

    return {"messages": results}
| 121 |
|
# Conditional Routing
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Route to tools if LLM made a tool call, else end."""
    pending_calls = getattr(state["messages"][-1], "tool_calls", None)
    if pending_calls:
        return "Action"
    return END
|
| 127 |
|
| 128 |
# --------------------------------------------------------------------
|
| 129 |
# 5. Build LangGraph Agent
|
|
|
|
| 167 |
result = gaia_agent.invoke(input_state, config)
|
| 168 |
final_response = result["messages"][-1].content
|
| 169 |
|
| 170 |
+
# Extract "FINAL ANSWER" if present
|
| 171 |
+
if isinstance(final_response, str):
|
| 172 |
return final_response.split("FINAL ANSWER:")[-1].strip()
|
| 173 |
+
else:
|
| 174 |
+
return str(final_response)
|
|
|