skyliulu commited on
Commit
1b5518c
·
1 Parent(s): f1fae13

default google

Browse files
Files changed (1) hide show
  1. agent.py +16 -7
agent.py CHANGED
@@ -16,7 +16,7 @@ from tools import *
16
  load_dotenv()
17
 
18
 
19
- def buildAgent(provider="groq"):
20
  # load the system prompt from the file
21
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
22
  system_prompt = f.read()
@@ -32,6 +32,14 @@ def buildAgent(provider="groq"):
32
  )
33
  elif provider == "groq":
34
  llm = ChatGroq(model="qwen-qwq-32b")
 
 
 
 
 
 
 
 
35
  else:
36
  raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
37
 
@@ -61,7 +69,7 @@ def buildAgent(provider="groq"):
61
  def retriever(state: MessagesState):
62
  """Retriever node"""
63
  # Handle the case when no similar questions are found
64
- return {"messages": [sys_msg] + state["messages"]}
65
 
66
  ## The graph
67
  builder = StateGraph(MessagesState)
@@ -83,11 +91,12 @@ def buildAgent(provider="groq"):
83
 
84
 
85
  if __name__ == "__main__":
86
- random_question_url = "https://agents-course-unit4-scoring.hf.space/random-question"
87
- response = requests.get(random_question_url, timeout=15)
88
- questions_data = response.json()
89
- question = questions_data.get("question")
90
- graph = buildAgent(provider="groq")
 
91
  messages = [HumanMessage(content=question)]
92
  print(messages)
93
  messages = graph.invoke({"messages": messages})
 
16
  load_dotenv()
17
 
18
 
19
+ def buildAgent(provider="google"):
20
  # load the system prompt from the file
21
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
22
  system_prompt = f.read()
 
32
  )
33
  elif provider == "groq":
34
  llm = ChatGroq(model="qwen-qwq-32b")
35
+ elif provider == "google":
36
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
37
+ elif provider == "openrouter":
38
+ llm = ChatOpenAI(
39
+ base_url="https://openrouter.ai/api/v1",
40
+ model="google/gemini-2.0-flash-001",
41
+ api_key=os.getenv("OPENROUTER_API_KEY"),
42
+ )
43
  else:
44
  raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
45
 
 
69
  def retriever(state: MessagesState):
70
  """Retriever node"""
71
  # Handle the case when no similar questions are found
72
+ return {"messages": state["messages"]}
73
 
74
  ## The graph
75
  builder = StateGraph(MessagesState)
 
91
 
92
 
93
  if __name__ == "__main__":
94
+ # random_question_url = "https://agents-course-unit4-scoring.hf.space/random-question"
95
+ # response = requests.get(random_question_url, timeout=15)
96
+ # questions_data = response.json()
97
+ # question = questions_data.get("question")
98
+ question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
99
+ graph = buildAgent(provider="google")
100
  messages = [HumanMessage(content=question)]
101
  print(messages)
102
  messages = graph.invoke({"messages": messages})