cowrycode commited on
Commit
59efa71
·
verified ·
1 Parent(s): 8480402

Create gaia_agent.py

Browse files
Files changed (1) hide show
  1. gaia_agent.py +307 -0
gaia_agent.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from operator import add
2
+ # from re import search
3
+ from typing import TypedDict, Annotated
4
+ from langgraph.graph.message import add_messages
5
+ #from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
6
+ from langgraph.prebuilt import ToolNode
7
+ from langgraph.graph import START, StateGraph
8
+ from langgraph.prebuilt import tools_condition
9
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
10
+ # from langchain_community.llms.ollama import Ollama
11
+ from langchain_community.tools import DuckDuckGoSearchRun
12
+ import os
13
+
14
+ #:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::#
15
+ from dotenv import load_dotenv
16
+ # from langgraph.graph import START, StateGraph, MessagesState
17
+ from langgraph.graph import START, StateGraph
18
+ from langgraph.prebuilt import tools_condition
19
+ from langgraph.prebuilt import ToolNode
20
+ # from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
21
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
22
+ from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
23
+ from langchain_core.messages import AnyMessage
24
+ from langchain_core.tools import Tool
25
+ from googleapiclient.discovery import build
26
+ from youtube_transcript_api import YouTubeTranscriptApi
27
+ from urllib.parse import parse_qs, urlparse
28
+ from openai import OpenAI
29
+ import pandas as pd
30
+ import chess
31
+ import chess.engine
32
+ # import tempfile
33
+ # from PIL import Image
34
+
35
+
36
+
37
+ #from tavily import TavilyClient
38
+ load_dotenv()
39
+ google_key = os.getenv("GOOGLE_SECRET_KEY")
40
+ my_search_engine_id = os.getenv("Google_WebSearch_Engine")
41
+ #TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
42
+ # client = TavilyClient(TAVILY_API_KEY)
43
+ OpenAI_key = os.getenv("OPENAI_API_KEY")
44
+ client = OpenAI(api_key=OpenAI_key)
45
+
46
+ yt_ap = YouTubeTranscriptApi()
47
+ #wikipedia.set_lang("en")
48
+ #:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::#
49
+
50
+ api_key = os.getenv("HF_TOKEN")
51
+
52
+
53
+ # search_tool = DuckDuckGoSearchRun()
54
+
55
+ # Generate the chat interface, including the tools
56
+ llm = HuggingFaceEndpoint(
57
+ #repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
58
+ repo_id="deepseek-ai/DeepSeek-R1-0528",
59
+ huggingfacehub_api_token=api_key,
60
+ timeout=300,
61
+ )
62
+
63
+ # # Initialize local Ollama model
64
+ # llm Ollama(model="qwen2.5-coder", base_url="http://127.0.0.1:11434")
65
+
66
+ #:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::#
67
+
68
+ def custom_multiply(__arg1: str) -> int:
69
+ # Expect something like "5,3"
70
+ a, b = map(int, __arg1.split(","))
71
+ return a * b
72
+
73
+ custom_multiply_tool = Tool(
74
+ name="custom_multiply",
75
+ func=custom_multiply,
76
+ description="Multiplies two numbers extracted from a string then returns the result.",
77
+ )
78
+
79
+ def web_search(input: str) -> str:
80
+ """Search Tavily for a query and return maximum 3 results.
81
+
82
+ Args:
83
+ query: The search query."""
84
+ service = build("customsearch", "v1", developerKey=google_key)
85
+ result = service.cse().list(q=input, cx=my_search_engine_id, num=4).execute()
86
+ formatted_docs = []
87
+ for doc in result.get("items", []):
88
+ content = doc.get("snippet", "No content available.")
89
+ source = doc.get("link", "No URL available.")
90
+
91
+ # Creating the desired XML-like output format
92
+ formatted_doc = (
93
+ f'<Document source="{source}">\n'
94
+ f'{content}\n'
95
+ f'</Document>'
96
+ )
97
+ formatted_docs.append(formatted_doc)
98
+
99
+ formatted_search_docs = "\n\n---\n\n".join(formatted_docs)
100
+
101
+ return formatted_search_docs
102
+
103
+ web_search_tool = Tool(
104
+ name="web_search",
105
+ func=web_search,
106
+ description="Useful for searching the web for relevant information to answer questions.",
107
+ )
108
+
109
+ def extract_video_id(url: str) -> str:
110
+ """
111
+ Extracts the video ID from a YouTube URL.
112
+ Args:
113
+ url (str): The full YouTube video URL.
114
+ Returns:
115
+ str: The extracted video ID or raises ValueError.
116
+ """
117
+ parsed = urlparse(url)
118
+ if parsed.hostname in {"www.youtube.com", "youtube.com"}:
119
+ qs = parse_qs(parsed.query)
120
+ if "v" in qs:
121
+ return qs["v"][0]
122
+ # fallback for youtu.be or raw IDs
123
+ return parsed.path.lstrip("/")
124
+
125
+ def fetch_youtube_details(video_url: str) -> str:
126
+ """
127
+ Fetches the transcript text for a given YouTube video.
128
+ Use the extracted transcript to answer questions about the video.
129
+ Args:
130
+ video_url (str): The YouTube video URL.
131
+ Returns:
132
+ str: Combined transcript text or an error message.
133
+ """
134
+ video_id = extract_video_id(video_url)
135
+
136
+ try:
137
+ # ✅ call on the class, NOT an instance
138
+ transcript_data = yt_ap.fetch(
139
+ video_id=video_id,
140
+ languages=["en"], #You can add as many languages, use yt_ap.list(video_id) function to get the langauges
141
+ )
142
+
143
+ #FROM TRANSCRIPT DATA, YOU CAN CREATE A OBJECT OF TRANSCRIPT SNIPET AND TIME
144
+ arr = [ {"text": snippet.text} for snippet in transcript_data]
145
+ return " ".join(f"{entry['text']}" for entry in arr)
146
+ except Exception as e:
147
+ return f"Error fetching video details: {str(e)}"
148
+
149
+ fetch_youtube_details_tool = Tool(
150
+ name="fetch_youtube_details",
151
+ func=fetch_youtube_details,
152
+ description="Fetches details from a YouTube video, including its transcript.",
153
+ )
154
+
155
+ def transcribe_audio(audio_file_path: str) -> str:
156
+ """
157
+ Transcribes speech from an audio file using OpenAI Whisper.
158
+ Use the extracted transcript to answer questions about the video.
159
+ Args:
160
+ audio_file_path
161
+ Returns:
162
+ str: Combined transcript text or an error message.
163
+ """
164
+ """Transcribe a .wav file using OpenAI Whisper."""
165
+ with open(audio_file_path, "rb") as audio_file:
166
+ response = client.audio.transcriptions.create(
167
+ model="whisper-1", # or "whisper-1" if available gpt-4o-transcribe
168
+ file=audio_file
169
+ )
170
+ return response.text
171
+
172
+ transcribe_audio_tool = Tool(
173
+ name="transcribe_audio",
174
+ func=transcribe_audio,
175
+ description="Transcribes audio from a file using OpenAI Whisper.",
176
+ )
177
+
178
+ def df_to_column_row_map(df):
179
+ """
180
+ Convert a pandas DataFrame into the format:
181
+ [
182
+ {
183
+ column1: {row1: value, row2: value, ...},
184
+ column2: {row1: value, row2: value, ...},
185
+ ...
186
+ }
187
+ ]
188
+ """
189
+ result = {}
190
+
191
+ for col in df.columns:
192
+ # Create row mapping like {row1: val1, row2: val2, ...}
193
+ col_dict = {f"row{i+1}": df.iloc[i][col] for i in range(len(df))}
194
+ result[col] = col_dict
195
+
196
+ return [result]
197
+
198
+ def excel_csv_reader(file_path: str, query: str = "") -> str:
199
+ """
200
+ Reads a CSV or Excel file and get the details as a dictionary array.
201
+ """
202
+ try:
203
+ _, ext = os.path.splitext(file_path.lower())
204
+ if ext == ".csv":
205
+ df = pd.read_csv(file_path)
206
+ elif ext in [".xls", ".xlsx"]:
207
+ df = pd.read_excel(file_path)
208
+ else:
209
+ return "Unsupported file format. Please upload CSV or Excel."
210
+
211
+ if df.empty:
212
+ return "The file is empty or unreadable."
213
+
214
+ return df_to_column_row_map(df)
215
+
216
+ except Exception as e:
217
+ return f"Error reading file: {str(e)}"
218
+
219
+ excel_csv_reader_tool = Tool(
220
+ name="excel_csv_reader",
221
+ func=excel_csv_reader,
222
+ description="Reads and summarizes data from Excel or CSV files.",
223
+ )
224
+
225
+
226
+ STOCKFISH_PATH = "/usr/local/Cellar/stockfish/17.1/bin/stockfish"
227
+
228
+ def analyze_position_from_fen(fen: str, time_limit: float = 1.0) -> str:
229
+ """
230
+ Uses Stockfish to analyze the best move from a given FEN string.
231
+ Args:
232
+ fen (str): Forsyth–Edwards Notation of the board.
233
+ time_limit (float): Time to let Stockfish think.
234
+ Returns:
235
+ str: Best move in algebraic notation.
236
+ """
237
+ try:
238
+ board = chess.Board(fen)
239
+ engine = chess.engine.SimpleEngine.popen_uci(STOCKFISH_PATH)
240
+ result = engine.play(board, chess.engine.Limit(time=time_limit))
241
+ engine.quit()
242
+ return board.san(result.move)
243
+ except Exception as e:
244
+ return f"Stockfish error: {e}"
245
+
246
+ def solve_chess_image(image_path: str) -> str:
247
+ """
248
+ Stub function for image-to-FEN. Replace with actual OCR/vision logic.
249
+
250
+ Args:
251
+ image_path (str): Path to chessboard image.
252
+ Returns:
253
+ str: Best move or error.
254
+ """
255
+ # Placeholder FEN for development (e.g., black to move, guaranteed mate)
256
+ sample_fen = "6k1/5ppp/8/8/8/8/5PPP/6K1 b - - 0 1"
257
+
258
+ try:
259
+ print(f"Simulating FEN extraction from image: {image_path}")
260
+ # Replace the above with actual OCR image-to-FEN logic
261
+ best_move = analyze_position_from_fen(sample_fen)
262
+ return f"Detected FEN: {sample_fen}\nBest move for Black: {best_move}"
263
+ except Exception as e:
264
+ return f"Image analysis error: {e}"
265
+
266
+ solve_chess_image_tool = Tool(
267
+ name="solve_chess_image",
268
+ func=solve_chess_image,
269
+ description="Analyzes a chess position from an image and suggests the best move.",
270
+ )
271
+
272
+ #:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::#
273
+
274
+ chat = ChatHuggingFace(llm=llm, verbose=True)
275
+ tools = [custom_multiply_tool, web_search_tool, fetch_youtube_details_tool, transcribe_audio_tool, excel_csv_reader_tool, solve_chess_image_tool]
276
+ # chat_with_tools = chat.bind_tools(tools)
277
+
278
+ # # Generate the AgentState and Agent graph
279
+ class AgentState(TypedDict):
280
+ messages: Annotated[list[AnyMessage], add_messages]
281
+
282
+
283
+ def build_graph():
284
+ chat_with_tools = chat.bind_tools(tools)
285
+
286
+ def assistant(state: AgentState):
287
+ return {
288
+ "messages": [chat_with_tools.invoke(state["messages"])],
289
+ }
290
+
291
+ ## The graph
292
+ builder = StateGraph(AgentState)
293
+
294
+ # Define nodes: these do the work
295
+ builder.add_node("assistant", assistant)
296
+ builder.add_node("tools", ToolNode(tools))
297
+
298
+ # Define edges: these determine how the control flow moves
299
+ builder.add_edge(START, "assistant")
300
+ builder.add_conditional_edges(
301
+ "assistant",
302
+ # If the latest message requires a tool, route to tools
303
+ # Otherwise, provide a direct response
304
+ tools_condition,
305
+ )
306
+ builder.add_edge("tools", "assistant")
307
+ return builder.compile()