Mohammad Haghir commited on
Commit
d7388f1
·
1 Parent(s): 546fc11

simple solution

Browse files
Files changed (3) hide show
  1. agent_utils.py +20 -0
  2. app.py +32 -30
  3. requirements.txt +2 -1
agent_utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.document_loaders import WikipediaLoader
2
+
3
+
4
+ def wiki_ret(question: str) -> str:
5
+ """ Retrieve docs from wikipedia """
6
+
7
+ # Search
8
+ search_docs = WikipediaLoader(query=question,
9
+ load_max_docs=2).load()
10
+
11
+ # Format
12
+ formatted_search_docs = "\n\n---\n\n".join(
13
+ [
14
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
15
+ for doc in search_docs
16
+ ]
17
+ )
18
+
19
+ return formatted_search_docs
20
+
app.py CHANGED
@@ -3,12 +3,20 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
6
- import re
7
- import json
 
 
 
 
 
8
 
9
  from langchain_groq import ChatGroq
10
  from langchain_core.messages import HumanMessage
11
- from langchain_community.document_loaders import WikipediaLoader
 
 
 
12
 
13
  # (Keep Constants as is)
14
  # --- Constants ---
@@ -18,7 +26,9 @@ groq_api_key = os.getenv("GROQ_API_KEY")
18
 
19
  llm = ChatGroq(api_key=groq_api_key, model="llama-3.3-70b-versatile")
20
 
21
-
 
 
22
 
23
 
24
  # --- Basic Agent Definition ---
@@ -26,35 +36,14 @@ llm = ChatGroq(api_key=groq_api_key, model="llama-3.3-70b-versatile")
26
  class BasicAgent:
27
  def __init__(self):
28
  print("BasicAgent initialized.")
 
29
 
30
-
31
- def wiki_ret(self, question: str) -> str:
32
- """ Retrieve docs from wikipedia """
33
-
34
- # Search query
35
- # structured_llm = llm.with_structured_output(SearchQuery)
36
- # search_query = structured_llm.invoke([search_instructions]+state['messages'])
37
-
38
- # Search
39
- search_docs = WikipediaLoader(query=question,
40
- load_max_docs=2).load()
41
-
42
- # Format
43
- formatted_search_docs = "\n\n---\n\n".join(
44
- [
45
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
46
- for doc in search_docs
47
- ]
48
- )
49
-
50
- return formatted_search_docs
51
-
52
- def __call__(self, question: str) -> str:
53
 
54
  # print(f"Agent received question (first 50 chars): {question[:50]}...")
55
  # fixed_answer = "This is a default answer. --- 1"
56
  # print(f"Agent returning fixed answer: {fixed_answer}")
57
- context = self.wiki_ret(question)
58
 
59
  prompt = f"""
60
  You are a general AI assistant. I will ask you a question.
@@ -65,8 +54,7 @@ class BasicAgent:
65
  a string, don't use articles, neither abbreviations (e.g. for cities), and write
66
  the digits in plain text unless specified otherwise. If you are asked for a comma
67
  separated list, apply the above rules depending of whether the element to be put
68
- in the list is a number or a string. Question: {question}
69
- For answering questions you can use context from wikiperdia: {context}"""
70
  # Your answer must be in the following format:
71
 
72
  # {{"task_id": "task_id_1", "model_answer": "Answer 1 from your model", "reasoning_trace": "The different steps by which your model reached answer 1"}}
@@ -84,6 +72,20 @@ class BasicAgent:
84
  # res = json.loads(json_str)
85
  return response
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def run_and_submit_all( profile: gr.OAuthProfile | None):
88
  """
89
  Fetches all questions, runs the BasicAgent on them, submits all answers,
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ # import re
7
+ # import json
8
+ from typing_extensions import TypedDict
9
+ from typing import list
10
+
11
+ import operator
12
+ from typing import Annotated
13
 
14
  from langchain_groq import ChatGroq
15
  from langchain_core.messages import HumanMessage
16
+
17
+ from langgraph.graph import START, END, StateGraph, ToolNode
18
+
19
+ from agent_utils import wiki_ret
20
 
21
  # (Keep Constants as is)
22
  # --- Constants ---
 
26
 
27
  llm = ChatGroq(api_key=groq_api_key, model="llama-3.3-70b-versatile")
28
 
29
+ class GraphState(TypedDict):
30
+ messages: Annotated[list, operator.add]
31
+ answer: str
32
 
33
 
34
  # --- Basic Agent Definition ---
 
36
  class BasicAgent:
37
  def __init__(self):
38
  print("BasicAgent initialized.")
39
+ self.graph = self.create_graph()
40
 
41
+ def agent(self, question: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # print(f"Agent received question (first 50 chars): {question[:50]}...")
44
  # fixed_answer = "This is a default answer. --- 1"
45
  # print(f"Agent returning fixed answer: {fixed_answer}")
46
+ # context = self.wiki_ret(question)
47
 
48
  prompt = f"""
49
  You are a general AI assistant. I will ask you a question.
 
54
  a string, don't use articles, neither abbreviations (e.g. for cities), and write
55
  the digits in plain text unless specified otherwise. If you are asked for a comma
56
  separated list, apply the above rules depending of whether the element to be put
57
+ in the list is a number or a string. Question: {question}"""
 
58
  # Your answer must be in the following format:
59
 
60
  # {{"task_id": "task_id_1", "model_answer": "Answer 1 from your model", "reasoning_trace": "The different steps by which your model reached answer 1"}}
 
72
  # res = json.loads(json_str)
73
  return response
74
 
75
+ def create_graph(self):
76
+ builder = StateGraph(GraphState)
77
+ builder.add_node("agent", self.agent)
78
+ builder.add_node("tools", ToolNode(tools = [wiki_ret]))
79
+
80
+ builder.add_edge(START, "agent")
81
+ builder.add_edge("agent", END)
82
+ return builder.compile()
83
+
84
+ def __call___(self, question):
85
+ response = (self.graph).invoke({"messages": question})
86
+ return response
87
+
88
+
89
  def run_and_submit_all( profile: gr.OAuthProfile | None):
90
  """
91
  Fetches all questions, runs the BasicAgent on them, submits all answers,
requirements.txt CHANGED
@@ -3,4 +3,5 @@ requests
3
  langchain-groq
4
  langchain-community
5
  langchain-core
6
- wikipedia
 
 
3
  langchain-groq
4
  langchain-community
5
  langchain-core
6
+ wikipedia
7
+ langgraph