i-dhilip commited on
Commit
011ec37
·
verified ·
1 Parent(s): 893b8b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -28
app.py CHANGED
@@ -1,52 +1,44 @@
1
  import os
2
  import gradio as gr
3
  import requests
4
- import inspect
5
  import pandas as pd
6
  from dotenv import load_dotenv
7
- from typing import List, Dict, Any, Tuple, Optional
8
 
9
  # LangChain imports
10
- from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
11
- from langchain_core.messages import BaseMessage
12
- from langchain.schema import Document
13
  from langchain_openai import ChatOpenAI
14
- # from langchain_google_genai import ChatGoogleGenerativeAI
15
  from langchain_community.tools.tavily_search import TavilySearchResults
16
  from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
17
  from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
18
  from langchain_community.tools.arxiv.tool import ArxivQueryRun
19
- from langgraph.graph import StateGraph, START, END
20
  from langgraph.prebuilt import ToolNode, tools_condition
21
- from dataclasses import dataclass
22
- from typing import TypedDict, Annotated, Literal
23
 
24
- # Constants
25
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
26
 
27
  class MessagesState(TypedDict):
28
  messages: List[BaseMessage]
29
 
30
- # Load system prompt
31
  try:
32
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
33
  system_prompt = f.read()
34
  except FileNotFoundError:
35
- system_prompt = """You are a helpful AI assistant that uses tools to find information and answer questions.
36
- When you don't know something, use the available tools to look up information. Be concise, direct, and provide accurate responses.
37
- Always cite your sources when using information from searches or reference materials."""
 
 
38
 
39
- # Advanced agent using LangGraph
40
  class AdvancedAgent:
41
  def __init__(self):
42
  print("Initializing AdvancedAgent with LangGraph, Wikipedia, Arxiv, and Gemini 2.0 Flash")
43
- load_dotenv() # Load environment variables from .env file
44
  self.graph = self.build_graph()
45
  print("Graph successfully built")
46
 
47
  def build_graph(self):
48
- """Build the LangGraph agent with necessary tools"""
49
-
50
  llm = ChatOpenAI(
51
  model="google/gemini-2.0-flash-001",
52
  temperature=0,
@@ -61,31 +53,25 @@ class AdvancedAgent:
61
  tools = [wikipedia_tool, arxiv_tool, tavily_search]
62
  print(f"Initialized {len(tools)} tools: Wikipedia, Arxiv, Tavily Search")
63
 
64
- sys_msg = SystemMessage(content=system_prompt)
65
  llm_with_tools = llm.bind_tools(tools)
66
 
67
  def assistant(state: MessagesState):
68
- """Assistant node that processes messages and generates responses"""
69
  messages = state["messages"]
70
- response = llm_with_tools.invoke(messages) # <-- messages, not dict!
71
- return {"messages": messages + [response]}
72
 
73
  tools_node = ToolNode(tools)
74
 
75
  builder = StateGraph(MessagesState)
76
  builder.add_node("assistant", assistant)
77
  builder.add_node("tools", tools_node)
78
-
79
- builder.set_entry_point("assistant")
80
  builder.add_edge("assistant", "tools")
81
- builder.add_edge("tools", "assistant")
82
- builder.add_edge("assistant", END)
83
  builder.add_conditional_edges(
84
  "assistant",
85
  tools_condition,
86
  {"tools": "tools", END: END}
87
  )
88
-
89
  return builder.compile()
90
 
91
  def __call__(self, question: str) -> str:
@@ -95,7 +81,8 @@ class AdvancedAgent:
95
  HumanMessage(content=question)
96
  ]
97
  try:
98
- result = self.graph.invoke(messages) # <-- messages, not dict!
 
99
  final_messages = result["messages"]
100
  ai_messages = [msg for msg in final_messages if isinstance(msg, AIMessage)]
101
  if not ai_messages:
 
1
  import os
2
  import gradio as gr
3
  import requests
 
4
  import pandas as pd
5
  from dotenv import load_dotenv
6
+ from typing import List, Dict, Any, Optional
7
 
8
  # LangChain imports
9
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
 
 
10
  from langchain_openai import ChatOpenAI
 
11
  from langchain_community.tools.tavily_search import TavilySearchResults
12
  from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
13
  from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
14
  from langchain_community.tools.arxiv.tool import ArxivQueryRun
15
+ from langgraph.graph import StateGraph, END
16
  from langgraph.prebuilt import ToolNode, tools_condition
17
+ from typing import TypedDict
 
18
 
 
19
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
20
 
21
  class MessagesState(TypedDict):
22
  messages: List[BaseMessage]
23
 
 
24
  try:
25
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
26
  system_prompt = f.read()
27
  except FileNotFoundError:
28
+ system_prompt = (
29
+ "You are a helpful AI assistant that uses tools to find information and answer questions.\n"
30
+ "When you don't know something, use the available tools to look up information. Be concise, direct, and provide accurate responses.\n"
31
+ "Always cite your sources when using information from searches or reference materials."
32
+ )
33
 
 
34
  class AdvancedAgent:
35
  def __init__(self):
36
  print("Initializing AdvancedAgent with LangGraph, Wikipedia, Arxiv, and Gemini 2.0 Flash")
37
+ load_dotenv()
38
  self.graph = self.build_graph()
39
  print("Graph successfully built")
40
 
41
  def build_graph(self):
 
 
42
  llm = ChatOpenAI(
43
  model="google/gemini-2.0-flash-001",
44
  temperature=0,
 
53
  tools = [wikipedia_tool, arxiv_tool, tavily_search]
54
  print(f"Initialized {len(tools)} tools: Wikipedia, Arxiv, Tavily Search")
55
 
 
56
  llm_with_tools = llm.bind_tools(tools)
57
 
58
  def assistant(state: MessagesState):
 
59
  messages = state["messages"]
60
+ response = llm_with_tools.invoke(messages)
61
+ return {"messages": messages + [response]} # Always return dict
62
 
63
  tools_node = ToolNode(tools)
64
 
65
  builder = StateGraph(MessagesState)
66
  builder.add_node("assistant", assistant)
67
  builder.add_node("tools", tools_node)
 
 
68
  builder.add_edge("assistant", "tools")
69
+ builder.set_entry_point("assistant")
 
70
  builder.add_conditional_edges(
71
  "assistant",
72
  tools_condition,
73
  {"tools": "tools", END: END}
74
  )
 
75
  return builder.compile()
76
 
77
  def __call__(self, question: str) -> str:
 
81
  HumanMessage(content=question)
82
  ]
83
  try:
84
+ # Initial state must be a dict with "messages" key!
85
+ result = self.graph.invoke({"messages": messages})
86
  final_messages = result["messages"]
87
  ai_messages = [msg for msg in final_messages if isinstance(msg, AIMessage)]
88
  if not ai_messages: