mrhenu committed on
Commit
4868771
·
verified ·
1 Parent(s): 76f87fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -32
app.py CHANGED
@@ -1,42 +1,50 @@
 
 
 
 
 
 
 
 
1
  import os
2
- import gradio as gr
3
  import requests
4
  import pandas as pd
5
- from typing import TypedDict, Annotated, Sequence
6
  import operator
 
 
7
  from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
8
  from langchain.agents import AgentExecutor
9
  from langchain_experimental.tools import PythonREPLTool
10
- from langchain_community.tools.youtube.search import YouTubeSearchTool
11
  from langchain_community.tools.tavily_search import TavilySearchResults
12
- from langchain_core.tools import tool
13
  from langchain_openai import ChatOpenAI
14
  from langgraph.graph import StateGraph
15
  from langgraph.prebuilt import ToolNode, tools_condition
16
 
17
- # --- Custom Image Analysis Tool ---------------------------------------------
 
 
18
 
19
  @tool("image_analysis", return_direct=True)
20
  def image_analysis(image_path: str, prompt: str) -> str:
21
- """Analyze an image located at image_path and answer according to prompt.
22
- image_path: path or URL to the image file
23
- prompt: the specific question or instruction about the image
24
  Returns a textual answer.
25
  """
26
- from PIL import Image
27
  import openai
 
28
 
29
  if not os.path.exists(image_path):
30
  return "Image path not found."
31
 
32
- # Load image bytes
33
  with open(image_path, "rb") as f:
34
  img_bytes = f.read()
35
 
36
- # Send to OpenAI vision-capable model (e.g., gpt-4o with vision)
37
  client = openai.OpenAI()
38
- response = client.chat.completions.create(
39
- model="gpt-4o-mini", # vision-capable
40
  messages=[
41
  {
42
  "role": "user",
@@ -47,46 +55,50 @@ def image_analysis(image_path: str, prompt: str) -> str:
47
  }
48
  ],
49
  )
50
- return response.choices[0].message.content.strip()
51
 
52
- # --- Main Application Logic --------------------------------------------------
53
 
54
  class AgentState(TypedDict):
55
- """State schema for the LangGraph agent."""
56
  messages: Annotated[Sequence[BaseMessage], operator.add]
57
 
 
 
 
 
 
 
 
 
58
 
59
- def create_langgraph_agent():
60
- print("Initializing Advanced LangGraph Agent with vision…")
61
 
62
- SYSTEM_PROMPT = """
63
- You are a general AI assistant for the GAIA test. I will ask you a question. Report your reasoning briefly, and finish with:
64
- FINAL ANSWER: [YOUR FINAL ANSWER]
65
- Follow the formatting rules strictly.
66
- """
67
 
68
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
69
 
 
70
  tools = [
71
  TavilySearchResults(max_results=3),
72
  PythonREPLTool(),
73
  YouTubeSearchTool(),
74
- image_analysis, # new vision tool
75
  ]
76
 
77
  # Optional FileManagement tools
78
  try:
79
  from langchain_community.agent_toolkits.file_management.toolkit import FileManagementToolkit
80
  tools.extend(FileManagementToolkit(root_dir=".").get_tools())
81
- except Exception:
82
- pass
 
83
 
84
  llm_with_tools = llm.bind_tools(tools)
85
 
86
- def agent_node(state):
87
- msgs = [SystemMessage(content=SYSTEM_PROMPT)] + list(state["messages"])
88
- reply = llm_with_tools.invoke(msgs)
89
- return {"messages": [reply]}
90
 
91
  graph = StateGraph(AgentState)
92
  graph.add_node("agent", agent_node)
@@ -95,6 +107,85 @@ Follow the formatting rules strictly.
95
  graph.add_conditional_edges("agent", tools_condition)
96
  graph.add_edge("tools", "agent")
97
 
98
- return graph.compile()
 
 
 
 
99
 
100
- # rest of app (run_agent, Gradio UI, evaluation) remains identical to V2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Full Hugging Face Spaces app.py for GAIA agent – includes image analysis tool.
2
+ Copy‑paste this file as‑is to your Space.
3
+ Requires:
4
+ - openai>=1.7.0 (for vision)
5
+ - langchain, langchain-community, langchain-openai, langchain-experimental, langgraph, gradio, pandas, requests, tavily-python, youtube-transcript-api
6
+ - PILLOW (installed automatically with Gradio)
7
+ """
8
+
9
  import os
 
10
  import requests
11
  import pandas as pd
12
+ import gradio as gr
13
  import operator
14
+ from typing import Sequence, Annotated, TypedDict
15
+
16
  from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
17
  from langchain.agents import AgentExecutor
18
  from langchain_experimental.tools import PythonREPLTool
 
19
  from langchain_community.tools.tavily_search import TavilySearchResults
20
+ from langchain_community.tools.youtube.search import YouTubeSearchTool
21
  from langchain_openai import ChatOpenAI
22
  from langgraph.graph import StateGraph
23
  from langgraph.prebuilt import ToolNode, tools_condition
24
 
25
+ # ------------------------ Vision Tool --------------------------------------
26
+
27
+ from langchain_core.tools import tool
28
 
29
  @tool("image_analysis", return_direct=True)
30
  def image_analysis(image_path: str, prompt: str) -> str:
31
+ """Analyze an image located at `image_path` according to `prompt`.
32
+ Example call from LLM: image_analysis{"image_path": "/mnt/data/cat.png", "prompt": "How many cats?"}
 
33
  Returns a textual answer.
34
  """
 
35
  import openai
36
+ from PIL import Image
37
 
38
  if not os.path.exists(image_path):
39
  return "Image path not found."
40
 
41
+ # Read image bytes
42
  with open(image_path, "rb") as f:
43
  img_bytes = f.read()
44
 
 
45
  client = openai.OpenAI()
46
+ completion = client.chat.completions.create(
47
+ model="gpt-4o-mini", # vision-capable
48
  messages=[
49
  {
50
  "role": "user",
 
55
  }
56
  ],
57
  )
58
+ return completion.choices[0].message.content.strip()
59
 
60
+ # --------------------- LangGraph Agent -------------------------------------
61
 
62
  class AgentState(TypedDict):
 
63
  messages: Annotated[Sequence[BaseMessage], operator.add]
64
 
65
# Answer-format instructions sent as the system message ahead of the
# conversation. Each tuple element below is one line of the final prompt;
# the empty element produces the blank line after the answer template.
SYSTEM_PROMPT = "\n".join(
    (
        "You are a general AI assistant. I will ask you a question. Report your thoughts, "
        "and finish your answer with the template:",
        "FINAL ANSWER: [YOUR FINAL ANSWER].",
        "",
        "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.",
        "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.",
        "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.",
        "If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.",
    )
)
73
 
 
 
74
 
75
+ def create_langgraph_agent() -> AgentExecutor:
76
+ print("Initializing LangGraph GAIA agent…")
 
 
 
77
 
78
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
79
 
80
+ # Base tools
81
  tools = [
82
  TavilySearchResults(max_results=3),
83
  PythonREPLTool(),
84
  YouTubeSearchTool(),
85
+ image_analysis,
86
  ]
87
 
88
  # Optional FileManagement tools
89
  try:
90
  from langchain_community.agent_toolkits.file_management.toolkit import FileManagementToolkit
91
  tools.extend(FileManagementToolkit(root_dir=".").get_tools())
92
+ print("FileManagement tools loaded.")
93
+ except Exception as e:
94
+ print("FileManagement toolkit unavailable:", e)
95
 
96
  llm_with_tools = llm.bind_tools(tools)
97
 
98
+ def agent_node(state: AgentState):
99
+ full_msgs = [SystemMessage(content=SYSTEM_PROMPT)] + list(state["messages"])
100
+ response = llm_with_tools.invoke(full_msgs)
101
+ return {"messages": [response]}
102
 
103
  graph = StateGraph(AgentState)
104
  graph.add_node("agent", agent_node)
 
107
  graph.add_conditional_edges("agent", tools_condition)
108
  graph.add_edge("tools", "agent")
109
 
110
+ executor = graph.compile()
111
+ print("LangGraph agent compiled.")
112
+ return executor
113
+
114
+ # --------------------- Helper to run one question ---------------------------
115
 
116
def _message_text(content) -> str:
    """Flatten a chat message ``content`` value into plain text.

    Message content may be a plain string or a list of content blocks
    (strings or dicts carrying a ``"text"`` entry). Blocks without text
    are ignored.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, (list, tuple)):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


def run_agent(agent_executor, question: str, *, recursion_limit: int = 15) -> str:
    """Run the compiled agent on one question and return its final answer.

    Args:
        agent_executor: Object exposing ``invoke(state, config=...)`` that
            returns a mapping with a ``"messages"`` list.
        question: The question text, sent as a single human message.
        recursion_limit: Value passed as ``recursion_limit`` in the invoke
            config (defaults to the previously hard-coded 15).

    Returns:
        The text after the last ``FINAL ANSWER:`` marker in the last
        message, stripped; the whole stripped message when the marker is
        absent; or ``"Error: <exception>"`` if the run raised.
    """
    print("New question:", question)
    try:
        result = agent_executor.invoke(
            {"messages": [HumanMessage(content=question)]},
            config={"recursion_limit": recursion_limit},
        )
        answer_raw = _message_text(result["messages"][-1].content)
    except Exception as err:
        # Best-effort per question: one failing run must not abort the batch,
        # so the error text is returned in place of an answer.
        print("Execution error:", err)
        return f"Error: {err}"

    # rpartition leaves the full text in the tail when the marker is missing,
    # so one expression covers both the marked and the unmarked reply.
    return answer_raw.rpartition("FINAL ANSWER:")[2].strip()
128
+
129
+ # --------------------- Evaluation / Submission ----------------------------
130
+
131
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, answer each with the agent, and submit the results.

    Args:
        profile: The logged-in user's OAuth profile, or ``None`` when nobody
            is logged in.

    Returns:
        A ``(status, table)`` pair. ``status`` is a human-readable message.
        ``table`` is a ``pandas.DataFrame`` with ``task_id``, ``question`` and
        ``submitted_answer`` columns, or ``None`` when the run stopped before
        any question was answered.
    """
    if not profile:
        return "Please login via the button.", None

    if not (os.getenv("TAVILY_API_KEY") and os.getenv("OPENAI_API_KEY")):
        return "Missing API keys (TAVILY / OPENAI)", None

    space_id = os.getenv("SPACE_ID")
    if not space_id:
        # The payload is still sent, but its code link cannot be valid.
        print("SPACE_ID is not set; the submitted agent_code link will not resolve.")

    try:
        agent_exec = create_langgraph_agent()
    except Exception as e:
        return f"Error initializing agent: {e}", None

    QUESTIONS_URL = "https://agents-course-unit4-scoring.hf.space/questions"
    SUBMIT_URL = "https://agents-course-unit4-scoring.hf.space/submit"

    try:
        q_resp = requests.get(QUESTIONS_URL, timeout=20)
        q_resp.raise_for_status()
        questions = q_resp.json()
    except Exception as e:
        return f"Error fetching questions: {e}", None

    # Do not run the agent or submit when the endpoint returned nothing usable.
    if not isinstance(questions, list) or not questions:
        return "Fetched questions list is empty or has an unexpected format.", None

    answers = []  # exactly what is submitted
    rows = []  # what is shown in the UI table (adds the question text)
    for item in questions:
        if not isinstance(item, dict):
            continue
        tid, qtext = item.get("task_id"), item.get("question")
        if not (tid and qtext):
            continue
        answer = run_agent(agent_exec, qtext)
        answers.append({"task_id": tid, "submitted_answer": answer})
        rows.append({"task_id": tid, "question": qtext, "submitted_answer": answer})

    if not answers:
        return "No valid questions were found; nothing was submitted.", None

    payload = {
        "username": profile.username.strip(),
        "agent_code": f"https://huggingface.co/spaces/{space_id}/tree/main",
        "answers": answers,
    }

    try:
        s_resp = requests.post(SUBMIT_URL, json=payload, timeout=240)
        s_resp.raise_for_status()
        r = s_resp.json()
        status = (
            f"Submission Successful!\nUser: {r.get('username')}\n"
            f"Score: {r.get('score', 'N/A')}% ({r.get('correct_count', '?')}/{r.get('total_attempted', '?')})\n"
            f"Message: {r.get('message', 'No message')}"
        )
        return status, pd.DataFrame(rows)
    except Exception as e:
        # Keep the answers visible even when the submission itself failed.
        return f"Error submitting answers: {e}", pd.DataFrame(rows)
178
+
179
+ # ------------------------ Gradio UI ---------------------------------------
180
+
181
# Page layout: title, login button, run button, then the two output widgets.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner (Vision‑enabled)")
    gr.LoginButton()
    submit_button = gr.Button("Run & Submit All Answers")
    status_box = gr.Textbox(
        label="Run Status",
        lines=5,
        interactive=False,
    )
    results_table = gr.DataFrame(
        label="Questions / Answers",
        wrap=True,
    )

    # No inputs are wired here. NOTE(review): the handler's ``profile``
    # argument is presumably injected by Gradio from the login session via
    # its type hint -- confirm against the Gradio OAuth documentation.
    submit_button.click(
        fn=run_and_submit_all,
        outputs=[status_box, results_table],
    )

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()