mabelwang21 commited on
Commit
e656aa6
·
1 Parent(s): b1d7643

update agent class with langgraph

Browse files
Files changed (1) hide show
  1. agent.py +205 -95
agent.py CHANGED
@@ -1,118 +1,228 @@
1
- from smolagents import ToolCallingAgent, tool
2
- from langchain_community.tools import DuckDuckGoSearchRun
3
- from langchain_community.utilities import WikipediaAPIWrapper
4
- from langchain.tools import BaseTool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from PIL import Image
6
  import pytesseract
7
- import fitz
8
- import ast
9
- import os
10
 
11
- # -------------------- TOOL DEFINITIONS --------------------
 
 
 
12
 
 
 
13
 
14
  @tool
15
- def web_search(query: str) -> str:
16
- """
17
- Search the web using DuckDuckGo.
18
-
19
- Args:
20
- query (str): The search query string.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- Returns:
23
- str: Summary of search results.
24
- """
25
- search = DuckDuckGoSearchRun()
26
- return search.run(query)
 
 
 
27
 
28
  @tool
29
  def wikipedia_search(query: str) -> str:
30
- """
31
- Look up a topic on Wikipedia and return relevant content.
32
-
33
- Args:
34
- query (str): The topic or term to search on Wikipedia.
35
-
36
- Returns:
37
- str: Extracted Wikipedia content.
38
- """
39
- wiki = WikipediaQueryRun()
40
- return wiki.run(query)
41
 
42
  @tool
43
  def image_recognition(image_path: str) -> str:
44
- """
45
- Perform OCR on an image to extract text.
46
-
47
- Args:
48
- image_path (str): Path to the image file.
49
-
50
- Returns:
51
- str: Extracted text from the image.
52
- """
53
- img = Image.open(image_path)
54
- return pytesseract.image_to_string(img)
55
 
56
  @tool
57
  def read_pdf(pdf_path: str) -> str:
58
- """
59
- Extract all text from a PDF document.
60
-
61
- Args:
62
- pdf_path (str): Path to the PDF file.
63
-
64
- Returns:
65
- str: Text content of the PDF.
66
- """
67
- doc = fitz.open(pdf_path)
68
- return "".join(page.get_text() for page in doc)
69
 
70
  @tool
71
- def calculate(expr: str) -> float:
72
- """
73
- Evaluate a simple math expression.
 
 
 
 
 
74
 
75
- Args:
76
- expr (str): The math expression to evaluate.
 
 
 
 
 
 
 
77
 
78
- Returns:
79
- float: Result of the expression.
80
- """
81
- def _eval(node):
82
- if isinstance(node, ast.BinOp):
83
- left = _eval(node.left)
84
- right = _eval(node.right)
85
- if isinstance(node.op, ast.Add): return left + right
86
- if isinstance(node.op, ast.Sub): return left - right
87
- if isinstance(node.op, ast.Mult): return left * right
88
- if isinstance(node.op, ast.Div): return left / right
89
- if isinstance(node.op, ast.Pow): return left ** right
90
- elif isinstance(node, ast.UnaryOp):
91
- operand = _eval(node.operand)
92
- if isinstance(node.op, ast.UAdd): return +operand
93
- if isinstance(node.op, ast.USub): return -operand
94
- elif isinstance(node, ast.Num):
95
- return node.n
96
- else:
97
- raise TypeError(f"Unsupported type: {node}")
98
- parsed = ast.parse(expr, mode='eval').body
99
- return _eval(parsed)
100
 
101
- # -------------------- AGENT CLASS --------------------
 
 
 
 
 
 
 
 
102
 
103
- tools = [web_search, wikipedia_search, image_recognition, read_pdf, calculate]
104
- HF_TOKEN = os.getenv("HF_API_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  class MyAgent:
106
- def __init__(self):
107
- from smolagents import HfApiModel
108
- self.agent = ToolCallingAgent(
109
- tools=tools,
110
- model=HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", token=HF_TOKEN) # or another supported model
111
- )
112
-
113
- def __call__(self, question: str) -> str:
114
- try:
115
- result = self.agent.run(question)
116
- return f"FINAL ANSWER: {result.strip()}"
117
- except Exception as e:
118
- return f"FINAL ANSWER: ERROR - {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import ast
3
+ import re
4
+ import operator as op
5
+ from pathlib import Path
6
+ from typing import List, TypedDict, Annotated, Optional
7
+
8
+ from langchain.tools import tool
9
+ from langchain_community.document_loaders import (
10
+ CSVLoader,
11
+ YoutubeLoader,
12
+ )
13
+
14
+ from langchain.chat_models import init_chat_model
15
+ from langchain.agents import initialize_agent, AgentType
16
+ from langchain_community.retrievers import BM25Retriever
17
+ from langchain.tools import Tool
18
+ from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
19
+ from langgraph.graph.message import add_messages
20
+ from langgraph.graph import START, StateGraph
21
+
22
+ from langgraph.prebuilt import ToolNode, tools_condition
23
+
24
+ from youtube_transcript_api import YouTubeTranscriptApi
25
  from PIL import Image
26
  import pytesseract
27
+ import fitz # PyMuPDF
 
 
28
 
29
+ # === System Prompt ===
30
+ SYSTEM_PROMPT = """
31
+ You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
32
+ FINAL ANSWER: [YOUR FINAL ANSWER].
33
 
34
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
35
+ """.strip()
36
 
37
  @tool
38
+ def calculate(expr: str) -> str:
39
+ """Evaluate a simple math expression and return the result."""
40
+ _OPERATORS = {
41
+ ast.Add: op.add,
42
+ ast.Sub: op.sub,
43
+ ast.Mult: op.mul,
44
+ ast.Div: op.truediv,
45
+ ast.Pow: op.pow,
46
+ ast.USub: op.neg,
47
+ }
48
+ def _eval(node):
49
+ if isinstance(node, ast.Num):
50
+ return node.n
51
+ elif isinstance(node, ast.BinOp):
52
+ return _OPERATORS[type(node.op)](_eval(node.left), _eval(node.right))
53
+ elif isinstance(node, ast.UnaryOp):
54
+ return _OPERATORS[type(node.op)](_eval(node.operand))
55
+ else:
56
+ raise ValueError(f"Unsupported expression: {ast.dump(node)}")
57
+ try:
58
+ parsed = ast.parse(expr, mode='eval').body
59
+ result = _eval(parsed)
60
+ return str(result)
61
+ except Exception as e:
62
+ return f"Error calculating expression: {e}"
63
 
64
+ @tool
65
+ def web_search(query: str) -> str:
66
+ """Search the web for current information using DuckDuckGo."""
67
+ try:
68
+ from langchain.utilities import DuckDuckGoSearchRun
69
+ return DuckDuckGoSearchRun().run(query)
70
+ except Exception as e:
71
+ return f"Error performing web search: {e}"
72
 
73
  @tool
74
  def wikipedia_search(query: str) -> str:
75
+ """Search Wikipedia for a general-topic query."""
76
+ try:
77
+ from langchain.utilities import WikipediaAPIWrapper
78
+ return WikipediaAPIWrapper().run(query)
79
+ except Exception as e:
80
+ return f"Error searching Wikipedia: {e}"
 
 
 
 
 
81
 
82
  @tool
83
  def image_recognition(image_path: str) -> str:
84
+ """Analyze and extract text from an image using Tesseract OCR."""
85
+ try:
86
+ img = Image.open(image_path)
87
+ return pytesseract.image_to_string(img)
88
+ except Exception as e:
89
+ return f"Error processing image: {e}"
 
 
 
 
 
90
 
91
  @tool
92
  def read_pdf(pdf_path: str) -> str:
93
+ """Read and extract text from a PDF document."""
94
+ try:
95
+ doc = fitz.open(pdf_path)
96
+ return "".join(page.get_text() for page in doc)
97
+ except Exception as e:
98
+ return f"Error reading PDF: {e}"
 
 
 
 
 
99
 
100
  @tool
101
+ def read_csv(csv_path: str) -> str:
102
+ """Read and extract text from a CSV file, row by row."""
103
+ try:
104
+ loader = CSVLoader(csv_path, encoding='utf-8')
105
+ docs = loader.load()
106
+ return "\n".join(doc.page_content for doc in docs)
107
+ except Exception as e:
108
+ return f"Error reading CSV: {e}"
109
 
110
+ @tool
111
+ def read_spreadsheet(spreadsheet_path: str) -> str:
112
+ """Read a spreadsheet into a DataFrame and return CSV text."""
113
+ try:
114
+ import pandas as pd
115
+ df = pd.read_excel(spreadsheet_path)
116
+ return df.to_csv(index=False)
117
+ except Exception as e:
118
+ return f"Error reading spreadsheet: {e}"
119
 
120
+ @tool
121
+ def transcribe_audio(audio_path: str) -> str:
122
+ """Transcribe audio file (e.g., MP3) using Whisper."""
123
+ try:
124
+ docs = AudioLoader(audio_path).load()
125
+ transcripts = WhisperLoader().load(docs)
126
+ return "\n".join(doc.page_content for doc in transcripts)
127
+ except Exception as e:
128
+ return f"Error transcribing audio: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ @tool
131
+ def youtube_transcript_tool(video_url: str) -> str:
132
+ """Download the transcript of a YouTube video using LangChain YoutubeLoader."""
133
+ try:
134
+ loader = YoutubeLoader.from_youtube_url(video_url)
135
+ docs = loader.load()
136
+ return "\n".join(doc.page_content for doc in docs)
137
+ except Exception as e:
138
+ return f"Error fetching YouTube transcript: {e}"
139
 
140
+ @tool
141
+ def youtube_transcript_api(video_url_or_id: str) -> str:
142
+ """Download transcript from YouTube using youtube-transcript-api."""
143
+ try:
144
+ match = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_-]{11})", video_url_or_id)
145
+ vid = match.group(1) if match else video_url_or_id
146
+ entries = YouTubeTranscriptApi.get_transcript(vid)
147
+ return " ".join(segment["text"] for segment in entries)
148
+ except Exception as e:
149
+ return f"Error fetching transcript via API: {e}"
150
+
151
+
152
+
153
+ #o3_mini = init_chat_model("openai:o3-mini", temperature=0)
154
+ #claude_sonnet = init_chat_model(anthropic:claude-3-5-sonnet-latest", temperature=0)
155
+ #gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
156
+
157
+ _ = os.getenv("ANTHROPIC_API_KEY")
158
+
159
+ tools = [
160
+ calculate, web_search, wikipedia_search, image_recognition,
161
+ read_pdf, read_csv, read_spreadsheet, transcribe_audio,
162
+ youtube_transcript_tool, youtube_transcript_api
163
+ ]
164
+ class AgentState(TypedDict):
165
+ # The document provided
166
+ input_file: Optional[str] # Contains file path (PDF/PNG)
167
+ messages: Annotated[list[AnyMessage], add_messages]
168
+
169
+ # === Agent Class ===
170
  class MyAgent:
171
+ def __init__(
172
+ self,
173
+ model_name: str = "anthropic:claude-3-5-sonnet-latest",
174
+ temperature: float = 0.0
175
+ ):
176
+ # Initialize LLM
177
+ self.llm = init_chat_model(model_name, temperature=temperature)
178
+
179
+ # Base tools: use provided tools or default list
180
+ self.tools = tools
181
+
182
+ # Human-readable tool descriptions
183
+ self.textual_tool_desc = "\n".join(t.__doc__.strip() for t in self.tools)
184
+
185
+ # Define assistant node
186
+ def assistant_node(state: AgentState) -> dict:
187
+ sys_msg = SystemMessage(
188
+ content="\n".join([
189
+ SYSTEM_PROMPT,
190
+ "\nTools available:\n" + self.textual_tool_desc
191
+ ])
192
+ )
193
+ msgs = [sys_msg] + state["messages"]
194
+ response = self.llm(msgs)
195
+ return {"messages": state["messages"] + [response], "input_file": state.get("input_file")}
196
+
197
+ # Condition to invoke tools: check if last LLM message mentions a tool invocation
198
+ def needs_tool(state: AgentState) -> bool:
199
+ last = state["messages"][-1].content.lower()
200
+ return any(f"{t.__name__.lower()}(" in last for t in self.tools)
201
+
202
+ # Build the state graph
203
+ builder = StateGraph(AgentState)
204
+ builder.add_node("assistant", assistant_node)
205
+ builder.add_node("tools", ToolNode(self.tools))
206
+ builder.add_edge(START, "assistant")
207
+ builder.add_conditional_edges("assistant", needs_tool)
208
+ builder.add_edge("tools", "assistant")
209
+
210
+ self.react_graph = builder.compile()
211
+
212
+ def __call__(
213
+ self,
214
+ user_input: str,
215
+ input_file: Optional[str] = None,
216
+ ) -> str:
217
+ state = AgentState()
218
+ state["messages"] = [HumanMessage(content=user_input)]
219
+ state["input_file"] = input_file
220
+ out = self.react_graph(state)
221
+ # Return only the final LLM message content
222
+ return out["messages"][-1].content.strip()
223
+
224
+ # CLI entrypoint
225
+ if __name__ == "__main__":
226
+ import fire
227
+ fire.Fire(MyAgent)
228
+