Spaces:
Sleeping
Sleeping
| import base64 | |
| import os | |
| import re | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from time import sleep | |
| from typing import TypedDict, Annotated, Optional | |
| import pandas as pd | |
| import requests | |
| from langchain_community.utilities.wikipedia import WikipediaAPIWrapper | |
| from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage | |
| from langchain_core.tools import tool | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_tavily import TavilySearch | |
| from langgraph.graph import START, StateGraph | |
| from langgraph.graph.message import add_messages | |
| from langgraph.prebuilt import ToolNode | |
| from langgraph.prebuilt import tools_condition | |
| from mediawikiapi import MediaWikiAPI | |
| from wikipedia_tool import WikipediaTool | |
| from yt_tool import speech_recognition_pipe, yt_transcribe | |
| from calculus_tools import add, substract, multiple, divide | |
def read_xlsx_file(file_path: str) -> str:
    """
    Read an XLSX file using pandas and return its content.

    Args:
        file_path: Path to the XLSX file

    Returns:
        Content of the XLSX file as markdown, or an error message
    """
    try:
        # Load the workbook into a DataFrame and render it as a markdown table.
        df = pd.read_excel(file_path)
        return df.to_markdown()
    except ImportError as e:
        # pandas itself is imported at module level, so an ImportError raised
        # here points at some other missing package; report the one Python
        # actually names instead of blaming pandas.
        return f"Error: a dependency required to read XLSX files is missing: {e}"
    except Exception as e:
        # Best-effort tool: failures are reported as text rather than raised,
        # so the caller always receives a string.
        return f"Error analyzing XLSX file: {e}"
| class Agent: | |
| def __init__(self): | |
| llm = ChatGoogleGenerativeAI( | |
| model="gemini-2.5-flash-preview-05-20", | |
| # model="gemini-2.0-flash", | |
| # model="gemini-1.5-pro", | |
| temperature=0 | |
| ) | |
| self.tools = [ | |
| WikipediaTool(api_wrapper=WikipediaAPIWrapper(wiki_client=MediaWikiAPI())), | |
| TavilySearch(), | |
| read_xlsx_file, | |
| add, | |
| substract, | |
| multiple, | |
| divide, | |
| yt_transcribe | |
| ] | |
| self.llm_with_tools = llm.bind_tools(self.tools) | |
| self.graph = self.build_graph() | |
| def build_graph(self): | |
| class AgentState(TypedDict): | |
| messages: Annotated[list[AnyMessage], add_messages] | |
| task_id: str | |
| file_name: Optional[str] | |
| def assistant(state: AgentState): | |
| try: | |
| messages = state.get("messages") | |
| # Invoke the LLM with tools | |
| response = self.llm_with_tools.invoke(messages) | |
| # Ensure we return the response in the correct format | |
| return { | |
| "messages": [response] | |
| } | |
| except Exception as e: | |
| # Create an error message if something goes wrong | |
| error_msg = AIMessage(content=f"Sorry, I encountered an error: {str(e)}") | |
| return { | |
| "messages": [error_msg] | |
| } | |
| def download_file_if_any(state: AgentState) -> str: | |
| if state.get("file_name"): | |
| return "download_file" | |
| else: | |
| return "assistant" | |
| def download_file(state: AgentState): | |
| filename = state.get("file_name") | |
| task_id = state.get("task_id") | |
| url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}" | |
| try: | |
| # Send a GET request to the URL | |
| response = requests.get(url, stream=True) | |
| # Ensure the request was successful | |
| response.raise_for_status() | |
| # Create a temporary file | |
| temp_dir = tempfile.gettempdir() # Get the temporary directory path | |
| temp_file_path = os.path.join(temp_dir, os.path.basename(filename)) | |
| # Open a local file in binary write mode | |
| with open(temp_file_path, 'wb') as file: | |
| # Write the content of the response to the file | |
| for chunk in response.iter_content(chunk_size=8192): | |
| file.write(chunk) | |
| return {} | |
| except requests.exceptions.RequestException as e: | |
| error_msg = AIMessage(content=f"Sorry, I encountered an error: {str(e)}") | |
| return { | |
| "messages": [error_msg] | |
| } | |
| def file_condition(state: AgentState) -> str: | |
| filename = state.get("file_name") | |
| suffix = Path(filename).suffix | |
| if suffix in [".png", ".jpeg"]: | |
| return "add_image_message" | |
| elif suffix in [".xlsx"]: | |
| return "add_xlsx_message" | |
| elif suffix in [".mp3"]: | |
| return "add_audio_message" | |
| elif suffix in [".py"]: | |
| return "add_py_message" | |
| else: | |
| return "assistant" | |
| def add_image_message(state: AgentState): | |
| filename = state.get("file_name") | |
| temp_dir = tempfile.gettempdir() # Get the temporary directory path | |
| image_path = os.path.join(temp_dir, os.path.basename(filename)) | |
| # Load the image and convert it to base64 | |
| with open(image_path, "rb") as img_file: | |
| base64_image = base64.b64encode(img_file.read()).decode("utf-8") | |
| # Construct the image message | |
| image_message = HumanMessage(content=[{ | |
| "type": "image_url", | |
| "image_url": { | |
| "url": f"data:image/jpeg;base64,{base64_image}" | |
| } | |
| }]) | |
| return { "messages" : state.get("messages") + [image_message] } | |
| def add_xlsx_message(state: AgentState): | |
| filename = state.get("file_name") | |
| temp_dir = tempfile.gettempdir() # Get the temporary directory path | |
| xlsx_path = os.path.join(temp_dir, os.path.basename(filename)) | |
| # Construct the message | |
| xlsx_message = HumanMessage(content=f"xlsx file is at {xlsx_path}") | |
| return { "messages" : state.get("messages") + [xlsx_message] } | |
| def add_audio_message(state: AgentState): | |
| filename = state.get("file_name") | |
| temp_dir = tempfile.gettempdir() # Get the temporary directory path | |
| audio_path = os.path.join(temp_dir, os.path.basename(filename)) | |
| result = speech_recognition_pipe(audio_path) | |
| audio_message = HumanMessage(result["text"]) | |
| return {"messages": state.get("messages") + [audio_message]} | |
| def add_py_message(state: AgentState): | |
| filename = state.get("file_name") | |
| temp_dir = tempfile.gettempdir() # Get the temporary directory path | |
| file_path = os.path.join(temp_dir, os.path.basename(filename)) | |
| with open(file_path, 'r') as file: | |
| content = file.read() | |
| py_message = HumanMessage(content=[{ | |
| "type": "text", | |
| "text": content | |
| }]) | |
| return {"messages": state.get("messages") + [py_message]} | |
| ## The graph | |
| builder = StateGraph(AgentState) | |
| # Define nodes: these do the work | |
| builder.add_node("assistant", assistant) | |
| builder.add_node("tools", ToolNode(self.tools)) | |
| builder.add_node("download_file", download_file) | |
| builder.add_node("add_image_message", add_image_message) | |
| builder.add_node("add_xlsx_message", add_xlsx_message) | |
| builder.add_node("add_py_message", add_py_message) | |
| builder.add_node("add_audio_message", add_audio_message) | |
| # Define edges: these determine how the control flow moves | |
| builder.add_conditional_edges( | |
| START, | |
| download_file_if_any | |
| ) | |
| # builder.add_edge("download_file", "assistant") | |
| builder.add_conditional_edges( | |
| "download_file", | |
| file_condition | |
| ) | |
| builder.add_edge("add_image_message", "assistant") | |
| builder.add_edge("add_xlsx_message", "assistant") | |
| builder.add_edge("add_py_message", "assistant") | |
| builder.add_edge("add_audio_message", "assistant") | |
| builder.add_conditional_edges( | |
| "assistant", | |
| # If the latest message requires a tool, route to tools | |
| # Otherwise, provide a direct response | |
| tools_condition | |
| ) | |
| builder.add_edge("tools", "assistant") | |
| return builder.compile() | |
| def run(self, question: str, task_id: str, file_name: str | None): | |
| system_prompt = SystemMessage(content=""" | |
| You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, use digit not letter, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. | |
| If you are asked a list of items separated by coma, add a space after each coma. | |
| If you are asked a list in alphabetical order, it is the first word of each item that matters. | |
| """) | |
| messages = [system_prompt, HumanMessage(content=question)] | |
| response = self.graph.invoke({"messages": messages, "task_id": task_id, "file_name": file_name}) | |
| answer = response['messages'][-1].content | |
| for m in response['messages']: | |
| m.pretty_print() | |
| # Regex to capture text after "FINAL ANSWER: " | |
| match = re.search(r'FINAL ANSWER:\s*(.*)', answer) | |
| if match: | |
| final_answer = match.group(1) | |
| print(final_answer) | |
| return final_answer | |
| return answer |