anssio commited on
Commit
e57c8b0
·
verified ·
1 Parent(s): eb37339

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +202 -0
agent.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AnssiO 17/08/2025
2
+
3
+ from langgraph.graph import StateGraph, START, END
4
+ from langchain_core.tools import tool
5
+ from langchain_openai import ChatOpenAI
6
+ from langchain_experimental.tools.python.tool import PythonREPLTool
7
+ from youtube_transcript_api import YouTubeTranscriptApi
8
+ from urllib.parse import urlparse, parse_qs
9
+ import os
10
+ from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
11
+ from langgraph.graph import MessagesState
12
+ from langchain_tavily import TavilySearch
13
+ from huggingface_hub import InferenceClient
14
+ import time
15
+ import requests
16
+ from io import BytesIO
17
+ from pypdf import PdfReader
18
+ from bs4 import BeautifulSoup
19
+ from markdownify import markdownify as md
20
+
21
+ openai_key = os.getenv("OPENAI_API_KEY")
22
+ os.environ["OPENAI_API_KEY"] = openai_key
23
+ tavily_key = os.getenv("TAVILY_API_KEY")
24
+ os.environ["TAVILY_API_KEY"] = tavily_key
25
+
26
+
27
+ @tool
28
+ def youtube_transcript(url: str) -> str:
29
+ """Get the transcript of a YouTube video from the full URL."""
30
+ def extract_video_id(url):
31
+ parsed = urlparse(url)
32
+ if parsed.hostname == "youtu.be":
33
+ return parsed.path[1:]
34
+ elif "youtube.com" in parsed.hostname:
35
+ return parse_qs(parsed.query).get("v", [None])[0]
36
+ return None
37
+
38
+ video_id = extract_video_id(url)
39
+ if not video_id:
40
+ return "Invalid YouTube URL."
41
+ transcript = YouTubeTranscriptApi.get_transcript(video_id)
42
+ return "\n".join([t["text"] for t in transcript])
43
+
44
+ @tool
45
+ def describe_image_url(image_url: str) -> str:
46
+ """Describe an image from a public URL using GPT-4o mini."""
47
+ client = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=10_000)
48
+ response = client.invoke([
49
+ {"role": "user", "content": [
50
+ {"type": "text", "text": "Describe this image."},
51
+ {"type": "image_url", "image_url": {"url": image_url}}
52
+ ]}
53
+ ])
54
+ return response.content
55
+
56
+ @tool
57
+ def calculator(expression: str) -> str:
58
+ """Evaluate a basic math expression."""
59
+ try:
60
+ return str(eval(expression))
61
+ except Exception as e:
62
+ return f"Error: {e}"
63
+
64
+ @tool
65
+ def get_webpage(page_url: str) -> str:
66
+ """Load a web page and return it to markdown if possible"""
67
+ try:
68
+ r = requests.get(page_url)
69
+ r.raise_for_status()
70
+ text = ""
71
+ # special case if page is a PDF file
72
+ if r.headers.get('Content-Type', '') == 'application/pdf':
73
+ pdf_file = BytesIO(r.content)
74
+ reader = PdfReader(pdf_file)
75
+ for page in reader.pages:
76
+ text += page.extract_text()
77
+ else:
78
+ soup = BeautifulSoup((r.text), 'html.parser')
79
+ if soup.body:
80
+ # convert to markdown
81
+ text = md(str(soup.body))
82
+ else:
83
+ # return the raw content
84
+ text = r.text
85
+ return text
86
+ except Exception as e:
87
+ return f"get_webpage_content failed: {e}"
88
+
89
+ search_tool = TavilySearch(api_key=tavily_key)
90
+
91
+ python_tool = PythonREPLTool()
92
+
93
+
94
+ tools = [
95
+ calculator,
96
+ search_tool,
97
+ python_tool,
98
+ get_webpage,
99
+ youtube_transcript,
100
+ describe_image_url,
101
+ ]
102
+
103
+
104
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=16384)
105
+
106
+ tools_by_name = {tool.name: tool for tool in tools}
107
+ llm_with_tools = llm.bind_tools(tools)
108
+
109
+
110
+ system_prompt = """\
111
+ You are a general AI assistant with tools.
112
+ I will ask you a question. Use your tools, and answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. \
113
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
114
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
115
+ If you are asked for a number, just give your FINAL ANSWER as that number.
116
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
117
+ If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
118
+ If you are asked to give the answer without abbreviations, please use the full spelling instead of abbreviations, e.g., transform Mr. to Mister, Dr. to Doctor, or St. to Saint.
119
+ If you use the python_repl tool (code interpreter), always end your code with `print(...)` to see the output.
120
+ """
121
+
122
+
123
+ def tool_node(state: dict):
124
+
125
+ result = []
126
+ for tool_call in state["messages"][-1].tool_calls:
127
+ tool = tools_by_name[tool_call["name"]]
128
+ observation = tool.invoke(tool_call["args"])
129
+ result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
130
+ return {"messages": result}
131
+
132
+
133
+ def llm_decision_node(state: MessagesState):
134
+ messages = state["messages"]
135
+ response = [llm_with_tools.invoke([SystemMessage(system_prompt)]+messages)]
136
+ return {"messages": response + messages}
137
+
138
+
139
+ def condition_router(state: MessagesState) -> str:
140
+ last_msg = state["messages"][-1]
141
+ if last_msg.tool_calls:
142
+ return "continue"
143
+ return END
144
+
145
+
146
+ builder = StateGraph(MessagesState)
147
+
148
+ # Nodes
149
+ builder.add_node("tool_node", tool_node)
150
+ builder.add_node("llm_decision", llm_decision_node)
151
+
152
+ # # Entry
153
+ builder.add_edge(START, "llm_decision")
154
+
155
+ # # Conditional loop back or exit
156
+ builder.add_conditional_edges("llm_decision", condition_router, {
157
+ END: END,
158
+ "continue": "tool_node"
159
+ })
160
+
161
+ builder.add_edge("tool_node", "llm_decision")
162
+
163
+ agent = builder.compile()
164
+
165
+
166
+ class BasicAgent:
167
+ def __init__(self):
168
+ print("BasicAgent initialized.")
169
+ def __call__(self, question: str, file_name_text="") -> str:
170
+ print(f"Agent received question (first 50 chars): {question[:50]}...")
171
+ # create the input
172
+ if file_name_text:
173
+ file_name, suffix = file_name_text.split(".")
174
+ if suffix == "mp3":
175
+ client = InferenceClient(provider="fal-ai")
176
+ file_url = "https://agents-course-unit4-scoring.hf.space/files/" + file_name
177
+ try:
178
+ audio_text = client.automatic_speech_recognition(file_url, model="openai/whisper-large-v3")
179
+ question = question + " The attached audio has been translated to text. Here is the text: " + audio_text
180
+ except:
181
+ question = question + " File URL:" + " 'https://agents-course-unit4-scoring.hf.space/files/" + file_name + "' (." + suffix + " file)"
182
+ else:
183
+ question = question + " File URL:" + " 'https://agents-course-unit4-scoring.hf.space/files/" + file_name + "' (." + suffix + " file)"
184
+ messages = [HumanMessage(content=question)]
185
+
186
+ # call the agent
187
+ messages = agent.invoke(
188
+ {"messages": messages},
189
+ {"recursion_limit": 30}
190
+ ) # maximum number of steps before hitting a stop condition
191
+
192
+ # post-process the response (keep only what's after "FINAL ANSWER:" for the exact match)
193
+ answer = str(messages["messages"][-1].content)
194
+ try:
195
+ answer = answer.split("FINAL ANSWER:")[-1].strip()
196
+ except:
197
+ print('Error in splitting final answer')
198
+
199
+ print(f"Agent returning the answer: {answer}")
200
+ return answer
201
+
202
+