ahnhs2k commited on
Commit
09fe5a5
·
1 Parent(s): 363a5be
Files changed (3) hide show
  1. agent.py +139 -0
  2. app.py +1 -90
  3. requirements.txt +4 -2
agent.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent.py
2
+ import os
3
+ import pickle
4
+
5
+ from langchain.tools.retriever import create_retriever_tool
6
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_huggingface import HuggingFaceEmbeddings
9
+
10
+ from langchain_openai import ChatOpenAI
11
+ from langchain_core.documents import Document
12
+ from langchain_core.messages import SystemMessage, HumanMessage
13
+ from langchain_core.tools import tool
14
+ from langchain_community.tools import DuckDuckGoSearchRun
15
+
16
+ from langgraph.graph import START, StateGraph, MessagesState
17
+ from langgraph.prebuilt import ToolNode, tools_condition
18
+
19
+ ddg = DuckDuckGoSearchRun()
20
+
21
+ # -----------------------
22
+ # Tools
23
+ # -----------------------
24
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia and return the text of up to two matching articles."""
    # The docstring above is load-bearing: ``@tool`` uses it as the tool
    # description shown to the LLM and refuses to build a tool without one.
    #
    # Returns {"wiki": <text>} where <text> is the page contents joined by
    # blank lines (an empty string when the loader yields no documents).
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    text = "\n\n".join(d.page_content for d in docs)
    return {"wiki": text}
29
+
30
+
31
@tool
def arxiv_search(query: str) -> dict:
    """Search arXiv and return the opening text of up to two matching papers."""
    # The docstring above is load-bearing: ``@tool`` uses it as the tool
    # description shown to the LLM and refuses to build a tool without one.
    #
    # Each paper is cut to its first 1000 characters to keep the tool output
    # (and therefore the prompt) small. Returns {"arxiv": <text>} with the
    # excerpts joined by blank lines.
    docs = ArxivLoader(query=query, load_max_docs=2).load()
    text = "\n\n".join(d.page_content[:1000] for d in docs)
    return {"arxiv": text}
36
+
37
+
38
# ``ddg`` is the DuckDuckGoSearchRun instance created once right after the
# imports; it is reused here rather than instantiated a second time.
@tool
def web_search(query: str) -> dict:
    """Search web using DuckDuckGo (no API key required)"""
    # Best-effort lookup: a failed search must not abort the agent run, so
    # the error is reported and an empty result is handed back to the LLM.
    try:
        result = ddg.run(query)
    except Exception as exc:
        print(f"web_search failed for {query!r}: {exc}")
        return {"web": ""}
    return {"web": result}
48
+
49
+
50
+ TOOLS = [wiki_search, arxiv_search, web_search]
51
+
52
+
53
+ # -----------------------
54
+ # System Prompt
55
+ # -----------------------
56
# System instructions for the agent: use tools for factual lookups and reply
# with the bare final answer only. Surrounding blank lines are stripped.
SYSTEM_PROMPT = """
You are solving GAIA benchmark questions.

You MUST:
- Use tools if factual information is required.
- Reason internally but DO NOT reveal reasoning.
- Output ONLY the final answer.
- No explanation.
- No extra text.
""".strip()

# Message object wrapping the prompt text, built once at import time.
SYS_MSG = SystemMessage(content=SYSTEM_PROMPT)
68
+
69
+
70
# -----------------------
# Retriever (FAISS retained)
# -----------------------
# NOTE(review): the requirements.txt shown in this commit does not list
# langchain-huggingface, sentence-transformers or faiss-cpu, which the
# imports used below appear to need — confirm the deployment installs them.

# Embedding model used to build the vector store.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# Build the store on the first run and cache it on disk; later runs load the
# cached copy instead of re-embedding.
# SECURITY: unpickling executes arbitrary code — "faiss.pkl" must only ever be
# written by this process / a trusted party.
if not os.path.exists("faiss.pkl"):
    seed_docs = [
        Document(page_content="GAIA questions require factual exact answers."),
    ]
    vector_store = FAISS.from_documents(seed_docs, embeddings)
    with open("faiss.pkl", "wb") as cache_out:
        pickle.dump(vector_store, cache_out)
else:
    with open("faiss.pkl", "rb") as cache_in:
        vector_store = pickle.load(cache_in)

# Expose the vector store as a LangChain tool.
# NOTE(review): this tool is not part of TOOLS in the code visible here —
# confirm whether it is meant to be registered with the agent.
retriever_tool = create_retriever_tool(
    name="question_retriever",
    description="Retrieve similar factual questions",
    retriever=vector_store.as_retriever(),
)
93
+
94
+
95
+ # -----------------------
96
+ # Graph Builder
97
+ # -----------------------
98
def build_agent():
    """Build and compile the tool-calling LangGraph agent.

    Graph shape: START -> "retriever" -> "assistant" <-> "tools". The
    assistant node calls the tool-bound LLM; ``tools_condition`` routes to
    the tool node when the reply contains tool calls and ends the run
    otherwise.

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.
    """
    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=128,
    )

    llm_with_tools = llm.bind_tools(TOOLS)

    def retriever(state: MessagesState):
        # Entry node kept so the graph keeps its "retriever" step. It no
        # longer writes the system prompt into the state: the messages
        # reducer appends unseen messages, so doing that here placed the
        # system message *after* the user's question.
        return {"messages": []}

    def assistant(state: MessagesState):
        # Prepend the system prompt on every call so it is always the first
        # message the model sees, including after tool round-trips.
        reply = llm_with_tools.invoke([SYS_MSG, *state["messages"]])
        return {"messages": [reply]}

    builder = StateGraph(MessagesState)

    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(TOOLS))

    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
125
+
126
+
127
+ # -----------------------
128
+ # Public API
129
+ # -----------------------
130
class BasicAgent:
    """Callable wrapper around the compiled agent graph."""

    def __init__(self):
        """Compile the graph once and keep it for all later questions."""
        compiled_graph = build_agent()
        self.graph = compiled_graph
        print("✅ LangGraph GPT-4o-mini Agent initialized")

    def __call__(self, question: str) -> str:
        """Run the graph on ``question`` and return the last message's text.

        Args:
            question: The question to answer, sent as a single human message.

        Returns:
            The content of the final message in the resulting state, with
            surrounding whitespace stripped.
        """
        initial_state = {"messages": [HumanMessage(content=question)]}
        final_state = self.graph.invoke(initial_state)
        last_message = final_state["messages"][-1]
        return last_message.content.strip()
app.py CHANGED
@@ -7,101 +7,12 @@ import inspect
7
  import pandas as pd
8
  from typing import TypedDict
9
 
10
- from langchain_openai import ChatOpenAI
11
- from langchain_core.messages import HumanMessage
12
- from langchain_community.tools import DuckDuckGoSearchRun
13
 
14
  # (Keep Constants as is)
15
  # --- Constants ---
16
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
17
 
18
- SYSTEM_PROMPT = """
19
- You are solving GAIA benchmark questions.
20
-
21
- You MUST:
22
- - Use the provided search results as the source of truth.
23
- - Reason internally but DO NOT show reasoning.
24
- - Output ONLY the final answer.
25
- - No explanation.
26
- - No extra text.
27
- """
28
-
29
- def clean_answer(text: str) -> str:
30
- if not text:
31
- return ""
32
- s = text.strip()
33
- s = s.replace("Final answer:", "").replace("Answer:", "").strip()
34
- s = s.splitlines()[0].strip()
35
- s = s.strip('"\'`')
36
- if len(s) > 1 and s.endswith("."):
37
- s = s[:-1].strip()
38
- return s
39
-
40
- # -------------------------------
41
- # State
42
- # -------------------------------
43
- class AgentState(TypedDict):
44
- question: str
45
- answer: str
46
-
47
- # -------------------------------
48
- # Tools & LLM
49
- # -------------------------------
50
- # Search tool (무료)
51
- search_tool = DuckDuckGoSearchRun()
52
-
53
- # LLM (OpenAI – 이미 네 환경에서 동작 확인됨)
54
- llm = ChatOpenAI(
55
- model="gpt-4o",
56
- temperature=0,
57
- max_tokens=96,
58
- )
59
-
60
- # -------------------------------
61
- # Agent
62
- # -------------------------------
63
- class BasicAgent:
64
- def __init__(self):
65
- print("Search-based GAIA Agent initialized.")
66
-
67
- def __call__(self, question: str) -> str:
68
- print(f"Question: {question[:80]}...")
69
-
70
- queries = [
71
- question,
72
- f"{question} wikipedia",
73
- f"{question} site:wikipedia.org",
74
- f"{question} fact",
75
- ]
76
-
77
- snippets = []
78
- for q in queries:
79
- try:
80
- r = search_tool.run(q)
81
- if r:
82
- snippets.append(r)
83
- time.sleep(0.5) # rate-limit 회피
84
- except Exception as e:
85
- print("Search error:", e)
86
-
87
- search_result = "\n\n".join(snippets)
88
-
89
- prompt = f"""
90
- {SYSTEM_PROMPT}
91
-
92
- Question:
93
- {question}
94
-
95
- Search Results:
96
- {search_result}
97
- """.strip()
98
-
99
- response = llm.invoke([HumanMessage(content=prompt)])
100
- answer = clean_answer(response.content)
101
-
102
- print(f"Answer: {answer}")
103
- return answer
104
-
105
  def run_and_submit_all( profile: gr.OAuthProfile | None):
106
  """
107
  Fetches all questions, runs the BasicAgent on them, submits all answers,
 
7
  import pandas as pd
8
  from typing import TypedDict
9
 
10
+ from agent import BasicAgent
 
 
11
 
12
  # (Keep Constants as is)
13
  # --- Constants ---
14
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def run_and_submit_all( profile: gr.OAuthProfile | None):
17
  """
18
  Fetches all questions, runs the BasicAgent on them, submits all answers,
requirements.txt CHANGED
@@ -1,8 +1,10 @@
1
  gradio
2
  requests
 
3
  langgraph
4
- langchain_openai
5
- langchain_core
6
  langchain-community
 
 
7
  ddgs
8
  duckduckgo-search
 
1
  gradio
2
  requests
3
+
4
  langgraph
5
+ langchain-core
 
6
  langchain-community
7
+ langchain-openai
8
+
9
  ddgs
10
  duckduckgo-search