Spaces:
Sleeping
Sleeping
| import os | |
| import requests | |
| from dotenv import load_dotenv | |
| from langgraph.graph import StateGraph, MessagesState, START | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langfuse.langchain import CallbackHandler | |
| from tools.web_search import web_search | |
| from tools.math import add_numbers_in_list, check_commutativity | |
| from tools.extraction import extract_data_from_excel, extract_transcript_from_youtube, extract_transcript_from_audio | |
| from rate_limiters import safe_invoke_with_retry_gemini | |
# Populate os.environ from .env first; override=True lets .env values replace
# variables already present in the process environment. This must run before
# CallbackHandler() below, which is built without explicit keys and so
# presumably reads its Langfuse credentials from the environment.
load_dotenv(override=True)

# LLM backend passed to build_agent() by the manual-test entry point.
PROVIDER = "google"

# Langfuse tracing callback, supplied in the agent's invoke() config.
langfuse_handler = CallbackHandler()

# Tools bound to the LLM (bind_tools) and executed by the graph's ToolNode.
tools = [
    add_numbers_in_list,
    web_search,
    check_commutativity,
    extract_data_from_excel,
    extract_transcript_from_youtube,
    extract_transcript_from_audio,
]
# --------------- Define the agent structure ---------------- #
def _create_llm(provider: str):
    """Instantiate the chat model for *provider* ("hf", "google" or "openai").

    Raises:
        ValueError: if *provider* is not one of the supported names.
    """
    if provider == "hf":
        endpoint = HuggingFaceEndpoint(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            task="text-generation",
            temperature=0.0,
            provider="hf-inference",
        )
        return ChatHuggingFace(llm=endpoint)
    if provider == "google":
        # Google Gemini. Alternatives tried: "gemini-2.0-flash",
        # "gemini-2.5-flash-lite-preview-06-17".
        return ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-preview-05-20",
            max_tokens=2048,
            max_retries=2,
        )
    if provider == "openai":
        return ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0,
            api_key=os.getenv("OPENAI_API_KEY"),
            max_tokens=512,
        )
    raise ValueError(f"Unsupported provider: {provider}")


def _load_system_prompt(path: str) -> str:
    """Return the UTF-8 text of the system prompt file at *path*.

    A relative *path* is first looked up relative to the current working
    directory (the historical behaviour). If nothing exists there, the
    directory containing this module is tried, so the agent can also be built
    when the process is started from another directory.

    Raises:
        OSError: if the file cannot be found or read.
    """
    if not os.path.isabs(path) and not os.path.exists(path):
        candidate = os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
        if os.path.exists(candidate):
            path = candidate
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def build_agent(provider: str = "hf", system_prompt_path: str = "system_prompt.txt"):
    """Build and compile the tool-calling LangGraph agent.

    The graph has two nodes:
    - "assistant": the LLM with the module-level ``tools`` bound.
    - "tools": a ``ToolNode`` that executes the requested tool calls.

    Execution starts at "assistant". ``tools_condition`` routes to "tools"
    when the assistant's reply requests tool calls, and "tools" always routes
    back to "assistant".

    Args:
        provider: LLM backend, one of "hf", "google" or "openai".
        system_prompt_path: File whose contents are sent as a SystemMessage
            ahead of the conversation on every LLM call.

    Returns:
        The compiled graph from ``StateGraph.compile()``.

    Raises:
        ValueError: if *provider* is unsupported, or if the environment
            variable USE_RATE_LIMITER is "true" and *provider* is not
            "google" (only Gemini has a rate-limited invoke helper).
        OSError: if the system prompt file cannot be read.
    """
    use_rate_limiter = os.getenv("USE_RATE_LIMITER", "false").lower() == "true"
    print(f"Building agent with provider: {provider}")
    llm = _create_llm(provider)
    # Validate up front: previously the graph built fine and then raised on
    # every single assistant call.
    if use_rate_limiter and provider != "google":
        raise ValueError(f"Rate limiting is not implemented for provider {provider}.")

    llm_with_tools = llm.bind_tools(tools)
    sys_msg = SystemMessage(content=_load_system_prompt(system_prompt_path))

    # --------------- Define nodes ---------------- #
    def assistant(state: MessagesState):
        """Node for the assistant to respond to user input."""
        # The system prompt is prepended on each call rather than kept in state.
        messages = [sys_msg] + state["messages"]
        if use_rate_limiter:
            response = safe_invoke_with_retry_gemini(
                llm_with_tools,
                messages,
                max_retries=2,
                wait_seconds=60,
            )
        else:
            response = llm_with_tools.invoke(messages)
        return {"messages": [response]}

    tool_node = ToolNode(tools=tools)

    # --------------- Build the state graph ---------------- #
    graph_builder = StateGraph(MessagesState)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", tool_node)
    graph_builder.add_conditional_edges("assistant", tools_condition)
    graph_builder.add_edge("tools", "assistant")
    graph_builder.add_edge(START, "assistant")
    return graph_builder.compile()
# --------------- For manual testing ---------------- #
def _fetch_questions(questions_url: str) -> list:
    """Fetch the question records from the scoring API.

    Returns:
        The decoded JSON list, or ``[]`` if the request or JSON decoding
        failed or the API returned nothing. A message is printed in each case.
    """
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body raised by response.json().
        print(f"An unexpected error occurred fetching questions: {e}")
        return []
    if not questions_data:
        print("Fetched questions list is empty.")
        return []
    print(f"Fetched {len(questions_data)} questions.")
    return questions_data


def _build_question_content(q_data: dict, files_url: str) -> list:
    """Build the HumanMessage content parts for one question record.

    The question text is always the first part. If the record names an
    attached file, one more part is added, chosen by the file extension
    (case-insensitive):
    - Images are passed by URL as an ``image_url`` part.
    - ``.py`` files are downloaded and inlined as text.
    - Excel, audio and other files are referenced by URL in a text part,
      presumably for the extraction tools to fetch.
    """
    content = [{"type": "text", "text": q_data["question"]}]
    file_name = q_data.get("file_name") or ""
    if not file_name:
        return content

    file_url = f"{files_url}{q_data['task_id']}"
    suffix_source = file_name.lower()
    if suffix_source.endswith((".png", ".jpg", ".jpeg")):
        content.append({"type": "image_url", "image_url": {"url": file_url}})
    elif suffix_source.endswith(".py"):
        # For code files, send the text content itself.
        try:
            response = requests.get(file_url, timeout=15)
            response.raise_for_status()
            content.append({"type": "text", "text": response.text})
        except requests.RequestException as e:
            print(f"Error fetching code file: {e}")
    elif suffix_source.endswith((".xlsx", ".xls")):
        content.append({"type": "text", "text": "Excel file url: " + file_url})
    elif suffix_source.endswith((".mp3", ".wav")):
        content.append({"type": "text", "text": "Audio file url: " + file_url})
    else:
        content.append({"type": "text", "text": f"File URL: {file_url} (file type not supported)"})
    return content


def main() -> None:
    """Run the agent once on a single question fetched from the scoring API.

    Pick the question by editing ``task_id`` below. The backend is the
    module-level ``PROVIDER`` constant. Exits with status 1 if the questions
    cannot be fetched or ``task_id`` is not among them.
    """
    print("\n" + "-"*30 + " Agent Starting " + "-"*30)
    print(f"Provider: {PROVIDER}")
    # USE_DDGS may be unset; default to "false" instead of calling .lower() on None.
    use_ddgs = os.getenv("USE_DDGS", "false").lower() == "true"
    print(f"Search engine used: {'DDGS' if use_ddgs else 'Tavily'}")
    agent = build_agent(provider=PROVIDER)  # edit PROVIDER to switch backend
    print("Agent built successfully.")
    print("-"*70)

    api_url = "https://agents-course-unit4-scoring.hf.space"
    questions_url = f"{api_url}/questions"
    files_url = f"{api_url}/files/"  # the task_id is appended per question

    questions_data = _fetch_questions(questions_url)
    if not questions_data:
        raise SystemExit(1)

    # Select the question to run by task_id:
    # task_id = "8e867cd7-cff9-4e6c-867a-ff5ddc2550be"  # Sosa albums
    # task_id = "2d83110e-a098-4ebb-9987-066c06fa42d0"  # Reverse text example
    # task_id = "cca530fc-4052-43b2-b130-b30968d8aa44"  # Chess image
    # task_id = "4fc2f1ae-8625-45b5-ab34-ad4433bc21f8"  # Dinosaur ?
    # task_id = "6f37996b-2ac7-44b0-8e68-6d28256631b4"  # Commutativity check
    # task_id = "9d191bce-651d-4746-be2d-7ef8ecadb9c2"  # Youtube video
    # task_id = "cabe07ed-9eca-40ea-8ead-410ef5e83f91"  # Louvrier ?
    # task_id = "f918266a-b3e0-4914-865d-4faa564f1aef"  # Code example
    # task_id = "3f57289b-8c60-48be-bd80-01f8099ca449"  # at bats ?
    task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733"  # Excel file
    # task_id = "5a0c1adf-205e-4841-a666-7c3ef95def9d"  # Malko competition
    # task_id = "305ac316-eef6-4446-960a-92d80d542f82"  # Poland film
    # task_id = "bda648d7-d618-4883-88f4-3466eabd860e"  # Vietnamese
    # task_id = "cf106601-ab4f-4af9-b045-5295fe67b37d"  # Olympics
    # task_id = "a0c07678-e491-4bbc-8f0b-07405144218f"  # pitchers
    # task_id = "3cef3a44-215e-4aed-8e3b-b1e3f08063b7"  # grocery list
    # task_id = "840bfca7-4f7b-481a-8794-c560c340185d"  # Carolyn Collins Petersen
    # task_id = "1f975693-876d-457b-a649-393859e79bf3"  # Audio (pages)
    # task_id = "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3"  # Audio (recipe)

    q_data = next((item for item in questions_data if item.get("task_id") == task_id), None)
    if q_data is None:
        print(f"No question found with task_id {task_id}.")
        raise SystemExit(1)

    human_msg = HumanMessage(content=_build_question_content(q_data, files_url))
    human_msg.pretty_print()

    try:
        result = agent.invoke(
            {"messages": [human_msg]},
            config={"callbacks": [langfuse_handler]},
        )
    except Exception as e:
        # Top-level boundary of a manual test script: report and stop.
        print(f"Error: {e}")
        return
    for message in result["messages"]:
        message.pretty_print()


if __name__ == "__main__":
    main()