Spaces:
Sleeping
Sleeping
Implement LangGraphAgent and enhance BasicAgent functionality; update requirements and add .env configuration
Browse files
.env
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OPENROUTER_API_KEY=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
|
| 2 |
+
OPENROUTER_API_URL=https://openrouter.ai/api/v1
|
agent.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from langgraph.graph import StateGraph, START, END, MessagesState
|
| 5 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 6 |
+
from langgraph.prebuilt import ToolNode # add
|
| 7 |
+
|
| 8 |
+
# Try to import tools from tools.py
|
| 9 |
+
try:
|
| 10 |
+
from .tools import get_tools as _get_tools # package-style
|
| 11 |
+
except Exception:
|
| 12 |
+
try:
|
| 13 |
+
from tools import get_tools as _get_tools # script-style
|
| 14 |
+
except Exception:
|
| 15 |
+
def _get_tools(): return [] # fallback
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
# Optional, used when OPENAI_API_KEY is available
|
| 19 |
+
from langchain_openai import ChatOpenAI
|
| 20 |
+
except Exception: # pragma: no cover - optional dependency resolution
|
| 21 |
+
ChatOpenAI = None # type: ignore
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class _EchoModel:
|
| 25 |
+
"""Simple stub model used when no API key / model is configured."""
|
| 26 |
+
|
| 27 |
+
def __init__(self, prefix: str = "[stub]"):
|
| 28 |
+
self.prefix = prefix
|
| 29 |
+
|
| 30 |
+
def invoke(self, messages):
|
| 31 |
+
last = messages[-1]
|
| 32 |
+
content = getattr(last, "content", str(last))
|
| 33 |
+
# Ensure the contract: always emit FINAL ANSWER:
|
| 34 |
+
return AIMessage(content=f"{self.prefix} FINAL ANSWER: You asked: {content}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class LangGraphAgent:
|
| 38 |
+
"""
|
| 39 |
+
Minimal LangGraph agent template.
|
| 40 |
+
|
| 41 |
+
Usage:
|
| 42 |
+
agent = LangGraphAgent()
|
| 43 |
+
answer = agent("What is the capital of France?")
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(self, *, model: Optional[object] = None, system_prompt: Optional[str] = None):
|
| 47 |
+
# Guide the model to use tools and to output a clear final answer.
|
| 48 |
+
base_prompt = system_prompt or "You are a helpful assistant. Keep answers concise."
|
| 49 |
+
self.system_prompt = (
|
| 50 |
+
base_prompt
|
| 51 |
+
+ "\n\nGuidelines:\n"
|
| 52 |
+
"- Use tools when they can verify facts or fetch fresh data.\n"
|
| 53 |
+
"- Think privately; do not reveal chain-of-thought.\n"
|
| 54 |
+
"- Provide the final user-facing result prefixed exactly with 'FINAL ANSWER:'."
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Choose an LLM if not provided
|
| 58 |
+
if model is None:
|
| 59 |
+
if ChatOpenAI is not None:
|
| 60 |
+
model = ChatOpenAI(
|
| 61 |
+
api_key=os.getenv("OPENROUTER_API_KEY"),
|
| 62 |
+
base_url=os.getenv("OPENROUTER_BASE_URL"),
|
| 63 |
+
model="openai/gpt-oss-20b:free",
|
| 64 |
+
)
|
| 65 |
+
else:
|
| 66 |
+
model = _EchoModel()
|
| 67 |
+
self.model = model
|
| 68 |
+
|
| 69 |
+
# Load tools and bind to the model if supported
|
| 70 |
+
self.tools = _get_tools()
|
| 71 |
+
self.llm = getattr(self.model, "bind_tools",
|
| 72 |
+
lambda _: self.model)(self.tools)
|
| 73 |
+
|
| 74 |
+
# Build a tool-using LangGraph: agent -> (maybe) tools -> agent -> ... -> END
|
| 75 |
+
def call_agent(state: MessagesState):
|
| 76 |
+
msgs = [SystemMessage(content=self.system_prompt)
|
| 77 |
+
] + list(state["messages"])
|
| 78 |
+
ai = self.llm.invoke(msgs)
|
| 79 |
+
return {"messages": [ai]}
|
| 80 |
+
|
| 81 |
+
def should_call_tools(state: MessagesState):
|
| 82 |
+
# If the last AI message includes tool calls, route to tools; else end.
|
| 83 |
+
last = state["messages"][-1]
|
| 84 |
+
if isinstance(last, AIMessage) and getattr(last, "tool_calls", None):
|
| 85 |
+
return "tools"
|
| 86 |
+
return "end"
|
| 87 |
+
|
| 88 |
+
builder = StateGraph(MessagesState)
|
| 89 |
+
builder.add_node("agent", call_agent)
|
| 90 |
+
builder.add_node("tools", ToolNode(self.tools))
|
| 91 |
+
builder.add_edge(START, "agent")
|
| 92 |
+
builder.add_edge("tools", "agent")
|
| 93 |
+
builder.add_conditional_edges("agent", should_call_tools, {
|
| 94 |
+
"tools": "tools", "end": END})
|
| 95 |
+
self.graph = builder.compile()
|
| 96 |
+
|
| 97 |
+
@staticmethod
|
| 98 |
+
def _extract_final_answer(text: str) -> str:
|
| 99 |
+
key = "FINAL ANSWER:"
|
| 100 |
+
idx = text.rfind(key)
|
| 101 |
+
return text[idx + len(key):].strip() if idx != -1 else text.strip()
|
| 102 |
+
|
| 103 |
+
def __call__(self, question: str) -> str:
|
| 104 |
+
state = {"messages": [HumanMessage(content=question)]}
|
| 105 |
+
result = self.graph.invoke(state)
|
| 106 |
+
messages = result.get("messages", [])
|
| 107 |
+
# Return only the content after "FINAL ANSWER:"
|
| 108 |
+
for msg in reversed(messages):
|
| 109 |
+
if isinstance(msg, AIMessage):
|
| 110 |
+
return self._extract_final_answer(msg.content)
|
| 111 |
+
return self._extract_final_answer(messages[-1].content) if messages else ""
|
app.py
CHANGED
|
@@ -10,25 +10,37 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
| 10 |
|
| 11 |
# --- Basic Agent Definition ---
|
| 12 |
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
|
|
|
|
|
|
| 13 |
class BasicAgent:
|
| 14 |
def __init__(self):
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def __call__(self, question: str) -> str:
|
| 17 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
def run_and_submit_all(
|
| 23 |
"""
|
| 24 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 25 |
and displays the results.
|
| 26 |
"""
|
| 27 |
# --- Determine HF Space Runtime URL and Repo URL ---
|
| 28 |
-
|
|
|
|
| 29 |
|
| 30 |
if profile:
|
| 31 |
-
username= f"{profile.username}"
|
| 32 |
print(f"User logged in: {username}")
|
| 33 |
else:
|
| 34 |
print("User not logged in.")
|
|
@@ -55,16 +67,16 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 55 |
response.raise_for_status()
|
| 56 |
questions_data = response.json()
|
| 57 |
if not questions_data:
|
| 58 |
-
|
| 59 |
-
|
| 60 |
print(f"Fetched {len(questions_data)} questions.")
|
| 61 |
except requests.exceptions.RequestException as e:
|
| 62 |
print(f"Error fetching questions: {e}")
|
| 63 |
return f"Error fetching questions: {e}", None
|
| 64 |
except requests.exceptions.JSONDecodeError as e:
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
except Exception as e:
|
| 69 |
print(f"An unexpected error occurred fetching questions: {e}")
|
| 70 |
return f"An unexpected error occurred fetching questions: {e}", None
|
|
@@ -81,18 +93,22 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 81 |
continue
|
| 82 |
try:
|
| 83 |
submitted_answer = agent(question_text)
|
| 84 |
-
answers_payload.append(
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
except Exception as e:
|
| 87 |
-
|
| 88 |
-
|
|
|
|
| 89 |
|
| 90 |
if not answers_payload:
|
| 91 |
print("Agent did not produce any answers to submit.")
|
| 92 |
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
|
| 93 |
|
| 94 |
-
# 4. Prepare Submission
|
| 95 |
-
submission_data = {"username": username.strip(
|
|
|
|
| 96 |
status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
|
| 97 |
print(status_update)
|
| 98 |
|
|
@@ -162,9 +178,11 @@ with gr.Blocks() as demo:
|
|
| 162 |
|
| 163 |
run_button = gr.Button("Run Evaluation & Submit All Answers")
|
| 164 |
|
| 165 |
-
status_output = gr.Textbox(
|
|
|
|
| 166 |
# Removed max_rows=10 from DataFrame constructor
|
| 167 |
-
results_table = gr.DataFrame(
|
|
|
|
| 168 |
|
| 169 |
run_button.click(
|
| 170 |
fn=run_and_submit_all,
|
|
@@ -175,22 +193,24 @@ if __name__ == "__main__":
|
|
| 175 |
print("\n" + "-"*30 + " App Starting " + "-"*30)
|
| 176 |
# Check for SPACE_HOST and SPACE_ID at startup for information
|
| 177 |
space_host_startup = os.getenv("SPACE_HOST")
|
| 178 |
-
space_id_startup = os.getenv("SPACE_ID")
|
| 179 |
|
| 180 |
if space_host_startup:
|
| 181 |
print(f"✅ SPACE_HOST found: {space_host_startup}")
|
| 182 |
-
print(
|
|
|
|
| 183 |
else:
|
| 184 |
print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
|
| 185 |
|
| 186 |
-
if space_id_startup:
|
| 187 |
print(f"✅ SPACE_ID found: {space_id_startup}")
|
| 188 |
print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
|
| 189 |
-
print(
|
|
|
|
| 190 |
else:
|
| 191 |
print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
|
| 192 |
|
| 193 |
print("-"*(60 + len(" App Starting ")) + "\n")
|
| 194 |
|
| 195 |
print("Launching Gradio Interface for Basic Agent Evaluation...")
|
| 196 |
-
demo.launch(debug=True, share=False)
|
|
|
|
| 10 |
|
| 11 |
# --- Basic Agent Definition ---
|
| 12 |
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
| 13 |
+
|
| 14 |
+
|
| 15 |
class BasicAgent:
|
| 16 |
def __init__(self):
|
| 17 |
+
from agent import LangGraphAgent
|
| 18 |
+
print("BasicAgent initialized (LangGraph).")
|
| 19 |
+
# Create a minimal LangGraph agent; will use OPENAI_API_KEY if set, else a stub echo
|
| 20 |
+
self._agent = LangGraphAgent()
|
| 21 |
+
|
| 22 |
def __call__(self, question: str) -> str:
|
| 23 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 24 |
+
try:
|
| 25 |
+
answer = self._agent(question)
|
| 26 |
+
except Exception as e:
|
| 27 |
+
print(f"LangGraph agent error: {e}")
|
| 28 |
+
answer = "Sorry, the agent hit a snag."
|
| 29 |
+
print(f"Agent returning answer (first 80 chars): {answer[:80]}...")
|
| 30 |
+
return answer
|
| 31 |
+
|
| 32 |
|
| 33 |
+
def run_and_submit_all(profile: gr.OAuthProfile | None):
|
| 34 |
"""
|
| 35 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 36 |
and displays the results.
|
| 37 |
"""
|
| 38 |
# --- Determine HF Space Runtime URL and Repo URL ---
|
| 39 |
+
# Get the SPACE_ID for sending link to the code
|
| 40 |
+
space_id = os.getenv("SPACE_ID")
|
| 41 |
|
| 42 |
if profile:
|
| 43 |
+
username = f"{profile.username}"
|
| 44 |
print(f"User logged in: {username}")
|
| 45 |
else:
|
| 46 |
print("User not logged in.")
|
|
|
|
| 67 |
response.raise_for_status()
|
| 68 |
questions_data = response.json()
|
| 69 |
if not questions_data:
|
| 70 |
+
print("Fetched questions list is empty.")
|
| 71 |
+
return "Fetched questions list is empty or invalid format.", None
|
| 72 |
print(f"Fetched {len(questions_data)} questions.")
|
| 73 |
except requests.exceptions.RequestException as e:
|
| 74 |
print(f"Error fetching questions: {e}")
|
| 75 |
return f"Error fetching questions: {e}", None
|
| 76 |
except requests.exceptions.JSONDecodeError as e:
|
| 77 |
+
print(f"Error decoding JSON response from questions endpoint: {e}")
|
| 78 |
+
print(f"Response text: {response.text[:500]}")
|
| 79 |
+
return f"Error decoding server response for questions: {e}", None
|
| 80 |
except Exception as e:
|
| 81 |
print(f"An unexpected error occurred fetching questions: {e}")
|
| 82 |
return f"An unexpected error occurred fetching questions: {e}", None
|
|
|
|
| 93 |
continue
|
| 94 |
try:
|
| 95 |
submitted_answer = agent(question_text)
|
| 96 |
+
answers_payload.append(
|
| 97 |
+
{"task_id": task_id, "submitted_answer": submitted_answer})
|
| 98 |
+
results_log.append(
|
| 99 |
+
{"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
| 100 |
except Exception as e:
|
| 101 |
+
print(f"Error running agent on task {task_id}: {e}")
|
| 102 |
+
results_log.append(
|
| 103 |
+
{"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
|
| 104 |
|
| 105 |
if not answers_payload:
|
| 106 |
print("Agent did not produce any answers to submit.")
|
| 107 |
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
|
| 108 |
|
| 109 |
+
# 4. Prepare Submission
|
| 110 |
+
submission_data = {"username": username.strip(
|
| 111 |
+
), "agent_code": agent_code, "answers": answers_payload}
|
| 112 |
status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
|
| 113 |
print(status_update)
|
| 114 |
|
|
|
|
| 178 |
|
| 179 |
run_button = gr.Button("Run Evaluation & Submit All Answers")
|
| 180 |
|
| 181 |
+
status_output = gr.Textbox(
|
| 182 |
+
label="Run Status / Submission Result", lines=5, interactive=False)
|
| 183 |
# Removed max_rows=10 from DataFrame constructor
|
| 184 |
+
results_table = gr.DataFrame(
|
| 185 |
+
label="Questions and Agent Answers", wrap=True)
|
| 186 |
|
| 187 |
run_button.click(
|
| 188 |
fn=run_and_submit_all,
|
|
|
|
| 193 |
print("\n" + "-"*30 + " App Starting " + "-"*30)
|
| 194 |
# Check for SPACE_HOST and SPACE_ID at startup for information
|
| 195 |
space_host_startup = os.getenv("SPACE_HOST")
|
| 196 |
+
space_id_startup = os.getenv("SPACE_ID") # Get SPACE_ID at startup
|
| 197 |
|
| 198 |
if space_host_startup:
|
| 199 |
print(f"✅ SPACE_HOST found: {space_host_startup}")
|
| 200 |
+
print(
|
| 201 |
+
f" Runtime URL should be: https://{space_host_startup}.hf.space")
|
| 202 |
else:
|
| 203 |
print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
|
| 204 |
|
| 205 |
+
if space_id_startup: # Print repo URLs if SPACE_ID is found
|
| 206 |
print(f"✅ SPACE_ID found: {space_id_startup}")
|
| 207 |
print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
|
| 208 |
+
print(
|
| 209 |
+
f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
|
| 210 |
else:
|
| 211 |
print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
|
| 212 |
|
| 213 |
print("-"*(60 + len(" App Starting ")) + "\n")
|
| 214 |
|
| 215 |
print("Launching Gradio Interface for Basic Agent Evaluation...")
|
| 216 |
+
demo.launch(debug=True, share=False)
|
prompt.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GAIA system prompt
|
| 2 |
+
system_prompt = """\
|
| 3 |
+
You are a general AI assistant.
|
| 4 |
+
I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
|
| 5 |
+
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
|
| 6 |
+
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
|
| 7 |
+
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
|
| 8 |
+
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
| 9 |
+
"""
|
requirements.txt
CHANGED
|
@@ -1,2 +1,7 @@
|
|
| 1 |
gradio
|
| 2 |
-
requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
gradio
|
| 2 |
+
requests
|
| 3 |
+
pandas
|
| 4 |
+
langgraph
|
| 5 |
+
langchain-core
|
| 6 |
+
langchain-openai
|
| 7 |
+
openai
|
tools.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from langchain_community.tools import Tool, BraveSearch, YouTubeSearchTool, ExtractTextTool
|
| 3 |
+
from langchain_community.tools import DuckDuckGoSearchResults, GoogleSearchResults
|
| 4 |
+
from langchain_community.tools import WikipediaQueryRun
|
| 5 |
+
from langchain_community.tools import WolframAlphaQueryRun
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
import json
|
| 8 |
+
import re
|
| 9 |
+
from datetime import datetime, timedelta
|
| 10 |
+
|
| 11 |
+
# Structured tools
|
| 12 |
+
try:
|
| 13 |
+
from langchain_core.tools import tool
|
| 14 |
+
except Exception:
|
| 15 |
+
def tool(*args, **kwargs):
|
| 16 |
+
def _wrap(fn): return fn
|
| 17 |
+
return _wrap
|
| 18 |
+
|
| 19 |
+
# Optional deps
|
| 20 |
+
try:
|
| 21 |
+
from youtube_transcript_api import (
|
| 22 |
+
YouTubeTranscriptApi,
|
| 23 |
+
TranscriptsDisabled,
|
| 24 |
+
NoTranscriptFound,
|
| 25 |
+
)
|
| 26 |
+
except Exception:
|
| 27 |
+
YouTubeTranscriptApi = None # type: ignore
|
| 28 |
+
TranscriptsDisabled = Exception # type: ignore
|
| 29 |
+
NoTranscriptFound = Exception # type: ignore
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from dateutil import parser as date_parser
|
| 33 |
+
from dateutil.relativedelta import relativedelta
|
| 34 |
+
except Exception:
|
| 35 |
+
date_parser = None # type: ignore
|
| 36 |
+
relativedelta = None # type: ignore
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
from zoneinfo import ZoneInfo # py>=3.9
|
| 40 |
+
except Exception:
|
| 41 |
+
ZoneInfo = None # type: ignore
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _parse_video_id(url_or_id: str) -> Optional[str]:
|
| 45 |
+
s = (url_or_id or "").strip()
|
| 46 |
+
if re.fullmatch(r"[0-9A-Za-z_-]{11}", s):
|
| 47 |
+
return s
|
| 48 |
+
try:
|
| 49 |
+
from urllib.parse import urlparse, parse_qs
|
| 50 |
+
u = urlparse(s)
|
| 51 |
+
if u.netloc.endswith(("youtube.com", "m.youtube.com", "music.youtube.com")):
|
| 52 |
+
qs = parse_qs(u.query)
|
| 53 |
+
v = (qs.get("v") or [""])[0]
|
| 54 |
+
if re.fullmatch(r"[0-9A-Za-z_-]{11}", v):
|
| 55 |
+
return v
|
| 56 |
+
if u.netloc.endswith("youtu.be"):
|
| 57 |
+
vid = u.path.lstrip("/").split("/")[0]
|
| 58 |
+
if re.fullmatch(r"[0-9A-Za-z_-]{11}", vid):
|
| 59 |
+
return vid
|
| 60 |
+
except Exception:
|
| 61 |
+
pass
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _to_dt(value: str, tz: Optional[str] = None) -> datetime:
|
| 66 |
+
if date_parser is not None:
|
| 67 |
+
dt = date_parser.parse(value)
|
| 68 |
+
else:
|
| 69 |
+
try:
|
| 70 |
+
dt = datetime.fromisoformat(value)
|
| 71 |
+
except Exception:
|
| 72 |
+
dt = datetime.strptime(value, "%Y-%m-%d")
|
| 73 |
+
if tz and ZoneInfo is not None:
|
| 74 |
+
try:
|
| 75 |
+
z = ZoneInfo(tz)
|
| 76 |
+
dt = dt.replace(
|
| 77 |
+
tzinfo=z) if dt.tzinfo is None else dt.astimezone(z)
|
| 78 |
+
except Exception:
|
| 79 |
+
pass
|
| 80 |
+
return dt
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@tool("youtube_transcript", return_direct=False)
|
| 84 |
+
def youtube_transcript(video: str, languages: Optional[List[str]] = None, max_chars: int = 8000) -> Dict[str, Any]:
|
| 85 |
+
"""
|
| 86 |
+
Get YouTube transcript for a video URL or ID.
|
| 87 |
+
Params:
|
| 88 |
+
- video: URL or 11-char video ID
|
| 89 |
+
- languages: preferred languages, e.g. ["vi","en"]
|
| 90 |
+
- max_chars: truncate long transcripts
|
| 91 |
+
"""
|
| 92 |
+
if YouTubeTranscriptApi is None:
|
| 93 |
+
return {"ok": False, "error": "youtube-transcript-api not installed. pip install youtube-transcript-api"}
|
| 94 |
+
vid = _parse_video_id(video)
|
| 95 |
+
if not vid:
|
| 96 |
+
return {"ok": False, "error": "Invalid YouTube video id/url."}
|
| 97 |
+
langs = languages or ["vi", "en"]
|
| 98 |
+
try:
|
| 99 |
+
segs = None
|
| 100 |
+
try:
|
| 101 |
+
segs = YouTubeTranscriptApi.get_transcript(vid, languages=langs)
|
| 102 |
+
except NoTranscriptFound:
|
| 103 |
+
try:
|
| 104 |
+
segs = YouTubeTranscriptApi.get_transcript(
|
| 105 |
+
vid, languages=["en"])
|
| 106 |
+
except Exception:
|
| 107 |
+
pass
|
| 108 |
+
if not segs:
|
| 109 |
+
try:
|
| 110 |
+
tx = YouTubeTranscriptApi.list_transcripts(vid)
|
| 111 |
+
for tr in tx:
|
| 112 |
+
if tr.is_translatable and "en" in langs:
|
| 113 |
+
segs = tr.translate("en").fetch()
|
| 114 |
+
break
|
| 115 |
+
except Exception:
|
| 116 |
+
pass
|
| 117 |
+
if not segs:
|
| 118 |
+
return {"ok": False, "error": "No transcript available."}
|
| 119 |
+
text = " ".join(s.get("text", "") for s in segs).strip()
|
| 120 |
+
if max_chars and len(text) > max_chars:
|
| 121 |
+
text = text[:max_chars] + " ...[truncated]..."
|
| 122 |
+
return {"ok": True, "data": {"video_id": vid, "text": text, "segments": segs}}
|
| 123 |
+
except TranscriptsDisabled:
|
| 124 |
+
return {"ok": False, "error": "Transcripts are disabled for this video."}
|
| 125 |
+
except Exception as e:
|
| 126 |
+
return {"ok": False, "error": f"Transcript fetch failed: {e}"}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@tool("date_today", return_direct=False)
|
| 130 |
+
def date_today(tz: Optional[str] = None) -> Dict[str, Any]:
|
| 131 |
+
"""
|
| 132 |
+
Return today's datetime fields.
|
| 133 |
+
"""
|
| 134 |
+
now = datetime.now(
|
| 135 |
+
ZoneInfo(tz)) if tz and ZoneInfo is not None else datetime.now()
|
| 136 |
+
return {"ok": True, "data": {"iso": now.isoformat(), "date": now.date().isoformat(), "time": now.time().isoformat(timespec="seconds")}}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@tool("date_parse", return_direct=False)
|
| 140 |
+
def date_parse(date_str: str, tz: Optional[str] = None) -> Dict[str, Any]:
|
| 141 |
+
"""
|
| 142 |
+
Parse a date/time string into ISO fields.
|
| 143 |
+
"""
|
| 144 |
+
try:
|
| 145 |
+
dt = _to_dt(date_str, tz)
|
| 146 |
+
return {"ok": True, "data": {"iso": dt.isoformat(), "date": dt.date().isoformat(), "time": dt.time().isoformat(timespec="seconds")}}
|
| 147 |
+
except Exception as e:
|
| 148 |
+
return {"ok": False, "error": f"Parse failed: {e}"}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@tool("date_add", return_direct=False)
|
| 152 |
+
def date_add(date_str: str, days: int = 0, months: int = 0, years: int = 0, tz: Optional[str] = None) -> Dict[str, Any]:
|
| 153 |
+
"""
|
| 154 |
+
Add/subtract days/months/years to a date/time.
|
| 155 |
+
"""
|
| 156 |
+
try:
|
| 157 |
+
dt = _to_dt(date_str, tz)
|
| 158 |
+
if relativedelta is not None:
|
| 159 |
+
dt2 = dt + relativedelta(days=days, months=months, years=years)
|
| 160 |
+
else:
|
| 161 |
+
if months or years:
|
| 162 |
+
return {"ok": False, "error": "Month/year arithmetic needs python-dateutil. pip install python-dateutil"}
|
| 163 |
+
dt2 = dt + timedelta(days=days)
|
| 164 |
+
return {"ok": True, "data": {"iso": dt2.isoformat(), "date": dt2.date().isoformat(), "time": dt2.time().isoformat(timespec="seconds")}}
|
| 165 |
+
except Exception as e:
|
| 166 |
+
return {"ok": False, "error": f"Add failed: {e}"}
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@tool("date_diff", return_direct=False)
|
| 170 |
+
def date_diff(start: str, end: str, unit: str = "days", tz: Optional[str] = None) -> Dict[str, Any]:
|
| 171 |
+
"""
|
| 172 |
+
Difference between two date/times. unit: days|hours|minutes|seconds.
|
| 173 |
+
"""
|
| 174 |
+
try:
|
| 175 |
+
d1 = _to_dt(start, tz)
|
| 176 |
+
d2 = _to_dt(end, tz)
|
| 177 |
+
seconds = (d2 - d1).total_seconds()
|
| 178 |
+
unit = (unit or "days").lower()
|
| 179 |
+
if unit == "seconds":
|
| 180 |
+
value = seconds
|
| 181 |
+
elif unit == "minutes":
|
| 182 |
+
value = seconds / 60
|
| 183 |
+
elif unit == "hours":
|
| 184 |
+
value = seconds / 3600
|
| 185 |
+
else:
|
| 186 |
+
unit = "days"
|
| 187 |
+
value = seconds / 86400
|
| 188 |
+
return {"ok": True, "data": {"value": value, "unit": unit}}
|
| 189 |
+
except Exception as e:
|
| 190 |
+
return {"ok": False, "error": f"Diff failed: {e}"}
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@tool("next_weekday", return_direct=False)
|
| 194 |
+
def next_weekday(date_str: str, weekday: int, include_today: bool = False, tz: Optional[str] = None) -> Dict[str, Any]:
|
| 195 |
+
"""
|
| 196 |
+
Next date matching weekday (0=Mon..6=Sun).
|
| 197 |
+
"""
|
| 198 |
+
try:
|
| 199 |
+
base = _to_dt(date_str, tz).date()
|
| 200 |
+
wd = int(weekday) % 7
|
| 201 |
+
delta = (wd - base.weekday()) % 7
|
| 202 |
+
if delta == 0 and not include_today:
|
| 203 |
+
delta = 7
|
| 204 |
+
target = base + timedelta(days=delta)
|
| 205 |
+
return {"ok": True, "data": {"date": target.isoformat(), "weekday": wd}}
|
| 206 |
+
except Exception as e:
|
| 207 |
+
return {"ok": False, "error": f"next_weekday failed: {e}"}
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@tool("date_format", return_direct=False)
|
| 211 |
+
def date_format(date_str: str, fmt: str = "%Y-%m-%d %H:%M:%S", tz: Optional[str] = None) -> Dict[str, Any]:
|
| 212 |
+
"""
|
| 213 |
+
Format a date/time string with strftime.
|
| 214 |
+
"""
|
| 215 |
+
try:
|
| 216 |
+
dt = _to_dt(date_str, tz)
|
| 217 |
+
return {"ok": True, "data": {"formatted": dt.strftime(fmt)}}
|
| 218 |
+
except Exception as e:
|
| 219 |
+
return {"ok": False, "error": f"Format failed: {e}"}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def get_tools():
|
| 223 |
+
"""
|
| 224 |
+
Returns a list of tools that can be used by the agent.
|
| 225 |
+
"""
|
| 226 |
+
tools = [
|
| 227 |
+
Tool(
|
| 228 |
+
name="BraveSearch",
|
| 229 |
+
func=BraveSearch().run,
|
| 230 |
+
description="Search the web using Brave Search."
|
| 231 |
+
),
|
| 232 |
+
Tool(
|
| 233 |
+
name="YouTubeSearch",
|
| 234 |
+
func=YouTubeSearchTool().run,
|
| 235 |
+
description="Search YouTube for videos."
|
| 236 |
+
),
|
| 237 |
+
Tool(
|
| 238 |
+
name="ExtractText",
|
| 239 |
+
func=ExtractTextTool().run,
|
| 240 |
+
description="Extract text from a given URL."
|
| 241 |
+
),
|
| 242 |
+
Tool(
|
| 243 |
+
name="DuckDuckGoSearch",
|
| 244 |
+
func=DuckDuckGoSearchResults().run,
|
| 245 |
+
description="Search the web using DuckDuckGo."
|
| 246 |
+
),
|
| 247 |
+
Tool(
|
| 248 |
+
name="GoogleSearch",
|
| 249 |
+
func=GoogleSearchResults().run,
|
| 250 |
+
description="Search the web using Google."
|
| 251 |
+
),
|
| 252 |
+
Tool(
|
| 253 |
+
name="WikipediaQuery",
|
| 254 |
+
func=WikipediaQueryRun().run,
|
| 255 |
+
description="Query Wikipedia for information."
|
| 256 |
+
),
|
| 257 |
+
Tool(
|
| 258 |
+
name="WolframAlphaQuery",
|
| 259 |
+
func=WolframAlphaQueryRun().run,
|
| 260 |
+
description="Query Wolfram Alpha for computational knowledge."
|
| 261 |
+
)
|
| 262 |
+
]
|
| 263 |
+
# Add structured tools (LangChain @tool)
|
| 264 |
+
tools.extend([
|
| 265 |
+
youtube_transcript,
|
| 266 |
+
date_today,
|
| 267 |
+
date_parse,
|
| 268 |
+
date_add,
|
| 269 |
+
date_diff,
|
| 270 |
+
next_weekday,
|
| 271 |
+
date_format,
|
| 272 |
+
])
|
| 273 |
+
return tools
|