blazingbunny committed on
Commit
22f855f
·
verified ·
1 Parent(s): b35514b

Upload 3 files

Browse files
Files changed (3) hide show
  1. agent.py +52 -32
  2. app.py +24 -1
  3. requirements.txt +2 -1
agent.py CHANGED
@@ -1,11 +1,15 @@
1
  from typing import TypedDict, Annotated, List
2
  import operator
3
  import os
 
 
4
  from langchain_google_genai import ChatGoogleGenerativeAI
5
- from langchain_core.messages import BaseMessage, HumanMessage
6
  from langgraph.graph import StateGraph, END, START
7
  from langgraph.prebuilt import ToolNode
8
  from langchain_tavily import TavilySearch
 
 
9
  from dotenv import load_dotenv
10
 
11
  load_dotenv()
@@ -15,11 +19,21 @@ class AgentState(TypedDict):
15
  messages: Annotated[List[BaseMessage], operator.add]
16
 
17
  # 2. Define the tools
18
- tools = [TavilySearch(max_results=1)]
 
 
 
 
 
 
 
 
 
 
19
  tool_node = ToolNode(tools)
20
 
21
  # 3. Define the model
22
- LLM = "gemini-2.0-flash-001"
23
  model = ChatGoogleGenerativeAI(model=LLM, temperature=0)
24
  model = model.bind_tools(tools)
25
 
@@ -27,55 +41,61 @@ model = model.bind_tools(tools)
27
  def should_continue(state):
28
  messages = state['messages']
29
  last_message = messages[-1]
30
- # If there are no tool calls, then we finish
31
  if not last_message.tool_calls:
32
  return "end"
33
- # Otherwise if there are tool calls, we continue
34
  else:
35
  return "continue"
36
 
37
  def call_model(state):
38
  messages = state['messages']
39
  response = model.invoke(messages)
40
- # We return a list, because this will get added to the existing list
41
  return {"messages": [response]}
42
 
43
  # 5. Create the graph
44
  workflow = StateGraph(AgentState)
45
-
46
- # Define the two nodes we will cycle between
47
  workflow.add_node("agent", call_model)
48
  workflow.add_node("action", tool_node)
49
-
50
- # Set the entrypoint as `agent`
51
- # This means that this node is the first one called
52
  workflow.add_edge(START, "agent")
53
-
54
- # We now add a conditional edge
55
- workflow.add_conditional_edges(
56
- "agent",
57
- should_continue,
58
- {
59
- "continue": "action",
60
- "end": END,
61
- },
62
- )
63
-
64
- # We now add a normal edge from `tools` to `agent`.
65
- # This means that after `tools` is called, `agent` node is called next.
66
  workflow.add_edge("action", "agent")
67
-
68
- # Finally, we compile it!
69
- # This compiles it into a LangChain Runnable,
70
- # meaning you can use it as you would any other runnable
71
  app = workflow.compile()
72
 
73
-
74
  class LangGraphAgent:
75
  def __init__(self):
76
  self.app = app
77
 
78
- def __call__(self, question: str) -> str:
79
- inputs = {"messages": [HumanMessage(content=question)]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  final_state = self.app.invoke(inputs)
81
- return final_state['messages'][-1].content
 
 
 
 
1
  from typing import TypedDict, Annotated, List
2
  import operator
3
  import os
4
+ import base64
5
+ import requests
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
8
  from langgraph.graph import StateGraph, END, START
9
  from langgraph.prebuilt import ToolNode
10
  from langchain_tavily import TavilySearch
11
+ from langchain_core.tools import tool
12
+ from langchain_community.document_loaders import YoutubeLoader
13
  from dotenv import load_dotenv
14
 
15
  load_dotenv()
 
19
  messages: Annotated[List[BaseMessage], operator.add]
20
 
21
  # 2. Define the tools
22
@tool
def get_youtube_transcript(url: str) -> str:
    """Retrieves the transcript of a YouTube video given its URL."""
    # Best-effort: any loader failure is reported back to the model as text
    # rather than raised, so the agent can recover and answer without it.
    try:
        yt_loader = YoutubeLoader.from_youtube_url(url, add_video_info=False)
        transcript_parts = [doc.page_content for doc in yt_loader.load()]
        return "\n".join(transcript_parts)
    except Exception as e:
        return f"Error getting transcript: {e}"
31
+
32
# Web search (capped at one result to keep context small) plus the YouTube
# transcript tool defined above.
tools = [TavilySearch(max_results=1), get_youtube_transcript]
tool_node = ToolNode(tools)  # executes whatever tool calls the model emits

# 3. Define the model
LLM = "gemini-3-pro-preview"
model = ChatGoogleGenerativeAI(model=LLM, temperature=0)  # temperature=0 for reproducible answers
model = model.bind_tools(tools)  # advertise the tool schemas so the model can call them
39
 
 
41
def should_continue(state):
    """Route after a model turn: "continue" to run tools, "end" to finish.

    Inspects the most recent message; a reply with pending tool calls sends
    the graph to the tool node, otherwise the conversation terminates.
    """
    latest = state['messages'][-1]
    return "continue" if latest.tool_calls else "end"
48
 
49
def call_model(state):
    """Invoke the bound model on the accumulated conversation.

    Returns the reply wrapped in a one-element list: `operator.add` on the
    `messages` channel appends it to the existing history.
    """
    reply = model.invoke(state['messages'])
    return {"messages": [reply]}
53
 
54
# 5. Create the graph
workflow = StateGraph(AgentState)

# Two nodes: the model call and the tool executor.
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# The agent runs first; after each model turn, should_continue routes either
# to the tool node or terminates the graph.
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END})
# Tool output always flows back to the agent for another model turn.
workflow.add_edge("action", "agent")

# Compile into a LangChain Runnable.
app = workflow.compile()
62
 
 
63
class LangGraphAgent:
    """Thin callable wrapper around the compiled LangGraph workflow (`app`)."""

    def __init__(self):
        # Reuse the module-level compiled graph; no other per-instance state.
        self.app = app

    @staticmethod
    def _fetch_image_part(task_id: str):
        """Return a multimodal `image_url` content part for the task's file, or None.

        Sends a cheap HEAD request first so missing or non-image attachments
        are skipped without downloading the payload.
        """
        image_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        try:
            head = requests.head(image_url, timeout=5)
            if head.status_code != 200 or "image" not in head.headers.get("Content-Type", ""):
                return None
            img_response = requests.get(image_url, timeout=10)
            if img_response.status_code != 200:
                return None
            image_data = base64.b64encode(img_response.content).decode("utf-8")
            # Strip parameters such as "; charset=..." — a data URL's mime type
            # must be bare, and headers like "image/png; charset=binary" exist.
            mime_type = head.headers.get("Content-Type", "image/jpeg").split(";")[0].strip() or "image/jpeg"
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{mime_type};base64,{image_data}"},
            }
        except requests.RequestException as e:
            # Best effort: an unreachable attachment must not abort the run.
            print(f"Error checking/fetching image: {e}")
            return None

    def __call__(self, question: str, task_id: str = None) -> str:
        """Answer `question`; when `task_id` is given, attach the task's image if any.

        Returns the final model reply flattened to a plain string.
        """
        messages = [
            SystemMessage(content="You are a helpful assistant. Answer the user's question directly and concisely. Do not include any introductory text or 'Final Answer:'. Just output the answer. If the question involves an image or video provided in the context, analyze it to answer."),
        ]

        content = [{"type": "text", "text": question}]
        if task_id:
            image_part = self._fetch_image_part(task_id)
            if image_part:
                content.append(image_part)

        messages.append(HumanMessage(content=content))
        final_state = self.app.invoke({"messages": messages})
        result = final_state['messages'][-1].content
        # Gemini may return a list of content parts; join them into one string.
        if isinstance(result, list):
            return " ".join(str(c) for c in result)
        return str(result)
app.py CHANGED
@@ -81,7 +81,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None, *args):
81
  print(f"Skipping item with missing task_id or question: {item}")
82
  continue
83
  try:
84
- submitted_answer = agent(question_text)
85
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer.strip()})
86
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
87
  except Exception as e:
@@ -200,6 +200,29 @@ with gr.Blocks() as demo:
200
  outputs=[answer_textbox]
201
  )
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  if __name__ == "__main__":
204
  print("\n" + "-"*30 + " App Starting " + "-"*30)
205
  # Check for SPACE_HOST and SPACE_ID at startup for information
 
81
  print(f"Skipping item with missing task_id or question: {item}")
82
  continue
83
  try:
84
+ submitted_answer = agent(question_text, task_id=task_id)
85
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer.strip()})
86
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
87
  except Exception as e:
 
200
  outputs=[answer_textbox]
201
  )
202
 
203
def export_results(df):
    """Write the results table to a text file and return its path for download.

    Returns None when the table is missing or empty. The file is created via
    tempfile (instead of a fixed "results.txt" in the working directory) so
    concurrent users/clicks on a shared Space cannot clobber each other's
    exports; Gradio's File component serves any path it is handed.
    """
    import tempfile  # local import: the file's top-level import block is untouched

    if df is None or df.empty:
        return None
    lines = []
    for _, row in df.iterrows():
        lines.append(f"Task ID: {row.get('Task ID', 'N/A')}")
        lines.append(f"Question: {row.get('Question', 'N/A')}")
        lines.append(f"Answer: {row.get('Submitted Answer', 'N/A')}")
        lines.append("-" * 40)
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", prefix="results_", delete=False, encoding="utf-8"
    ) as f:
        f.write("\n".join(lines) + "\n")
        return f.name
214
+
215
gr.Markdown("---")
gr.Markdown("## Tools")
# Offer the last run's results table as a downloadable plain-text file.
export_button = gr.Button("Export Results to Text")
file_output = gr.File(label="Download Results")

export_button.click(
    fn=export_results,        # serializes the results DataFrame to text
    inputs=[results_table],
    outputs=[file_output]
)
225
+
226
  if __name__ == "__main__":
227
  print("\n" + "-"*30 + " App Starting " + "-"*30)
228
  # Check for SPACE_HOST and SPACE_ID at startup for information
requirements.txt CHANGED
@@ -8,4 +8,5 @@ tavily-python
8
  langchain-google-genai
9
  google-auth
10
  langchain-tavily
11
- google-cloud-aiplatform
 
 
8
  langchain-google-genai
9
  google-auth
10
  langchain-tavily
11
+ google-cloud-aiplatform
12
+ youtube-transcript-api
+ langchain-community