ahnhs2k commited on
Commit
ab67ce4
ยท
1 Parent(s): 9a8f9c1
Files changed (2) hide show
  1. agent.py +130 -33
  2. requirements.txt +5 -1
agent.py CHANGED
@@ -1,46 +1,143 @@
1
  # agent.py
2
- # =========================================================
3
- # GAIA Level 1 ๋ชฉํ‘œ์šฉ ์ตœ์†Œ Agent
4
- # - ๋ชฉํ‘œ: 30% ์ด์ƒ
5
- # - ์ „๋žต: GPT-4o-mini ๋‹จ๋… ์ง๋‹ต
6
- # - ์•ˆ์ •์„ฑ ์ตœ์šฐ์„  (์ ˆ๋Œ€ ํ„ฐ์ง€์ง€ ์•Š๊ฒŒ)
7
- # =========================================================
 
8
 
9
  from langchain_openai import ChatOpenAI
10
  from langchain_core.messages import SystemMessage, HumanMessage
 
11
 
 
12
 
13
- SYSTEM_PROMPT = """
14
- You are answering GAIA benchmark questions.
15
 
16
- Rules:
17
- - Answer as concisely as possible.
18
- - Output ONLY the final answer.
19
- - No explanation.
20
- - If the answer is a list, separate items with commas.
21
- - If unsure, make your best factual guess.
22
- """.strip()
 
23
 
24
 
25
- class BasicAgent:
26
- def __init__(self):
27
- # OpenAI API Key๋Š” HF Space Secret์— OPENAI_API_KEY๋กœ ์žˆ์–ด์•ผ ํ•จ
28
- self.llm = ChatOpenAI(
29
- model="gpt-4o-mini",
30
- temperature=0,
31
- max_tokens=64,
32
- )
 
 
 
 
 
 
 
33
 
34
- self.system_msg = SystemMessage(content=SYSTEM_PROMPT)
35
- print("โœ… BasicAgent initialized (GAIA Level 1 minimal)")
36
 
37
- def __call__(self, question: str) -> str:
38
- messages = [
39
- self.system_msg,
40
- HumanMessage(content=question),
41
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- response = self.llm.invoke(messages)
 
44
 
45
- # ์•ˆ์ „์žฅ์น˜: ๋ฌด์กฐ๊ฑด str ๋ฐ˜ํ™˜
46
- return response.content.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # agent.py
2
+ # =====================================================
3
+ # GAIA Level 1 ~ 50% Target Agent
4
+ # LangGraph minimal + hard postprocess
5
+ # =====================================================
6
+
7
+ import re
8
+ from typing import TypedDict
9
 
10
  from langchain_openai import ChatOpenAI
11
  from langchain_core.messages import SystemMessage, HumanMessage
12
+ from langchain_community.tools import DuckDuckGoSearchRun
13
 
14
+ from langgraph.graph import StateGraph, START, END
15
 
 
 
16
 
17
+ # -----------------------------
18
+ # State
19
+ # -----------------------------
20
+ class State(TypedDict):
21
+ question: str
22
+ search_result: str
23
+ answer: str
24
+ searched: bool
25
 
26
 
27
+ # -----------------------------
28
+ # System Prompt (VERY IMPORTANT)
29
+ # -----------------------------
30
+ SYS = SystemMessage(
31
+ content=(
32
+ "You answer GAIA benchmark questions.\n"
33
+ "Rules:\n"
34
+ "- If factual info is needed, search ONCE.\n"
35
+ "- NEVER say you cannot access files, images, or audio.\n"
36
+ "- Output ONLY the final answer.\n"
37
+ "- No explanation.\n"
38
+ "- No quotes.\n"
39
+ "- No punctuation unless part of the answer."
40
+ )
41
+ )
42
 
 
 
43
 
44
+ # -----------------------------
45
+ # Tools
46
+ # -----------------------------
47
+ search = DuckDuckGoSearchRun()
48
+
49
+
50
+ # -----------------------------
51
+ # Postprocess (CORE)
52
+ # -----------------------------
53
+ def postprocess(text: str) -> str:
54
+ t = text.strip()
55
+
56
+ # kill refusal patterns
57
+ if re.search(r"(cannot|unable|sorry|please provide)", t.lower()):
58
+ return ""
59
+
60
+ # remove quotes
61
+ t = t.strip("\"'")
62
+
63
+ # remove trailing punctuation
64
+ t = re.sub(r"[.ใ€‚]$", "", t)
65
+
66
+ # first line only
67
+ if "\n" in t:
68
+ t = t.split("\n")[0].strip()
69
+
70
+ # numeric extraction if number-like
71
+ nums = re.findall(r"\d+", t)
72
+ if nums and len(t) > 6:
73
+ return nums[0]
74
+
75
+ return t
76
+
77
+
78
+ # -----------------------------
79
+ # Nodes
80
+ # -----------------------------
81
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=64)
82
+
83
+ def decide_and_search(state: State):
84
+ if state["searched"]:
85
+ return state
86
 
87
+ q = state["question"]
88
+ result = search.run(q)
89
 
90
+ return {
91
+ "question": q,
92
+ "search_result": result,
93
+ "searched": True,
94
+ }
95
+
96
+ def answer(state: State):
97
+ prompt = (
98
+ f"Question:\n{state['question']}\n\n"
99
+ f"Search result:\n{state.get('search_result','')}\n\n"
100
+ "Final answer:"
101
+ )
102
+
103
+ msg = llm.invoke([SYS, HumanMessage(content=prompt)])
104
+ clean = postprocess(msg.content)
105
+
106
+ return {
107
+ **state,
108
+ "answer": clean,
109
+ }
110
+
111
+
112
+ # -----------------------------
113
+ # Graph
114
+ # -----------------------------
115
+ def build():
116
+ g = StateGraph(State)
117
+
118
+ g.add_node("search", decide_and_search)
119
+ g.add_node("answer", answer)
120
+
121
+ g.add_edge(START, "search")
122
+ g.add_edge("search", "answer")
123
+ g.add_edge("answer", END)
124
+
125
+ return g.compile()
126
+
127
+
128
+ # -----------------------------
129
+ # Public API
130
+ # -----------------------------
131
+ class BasicAgent:
132
+ def __init__(self):
133
+ self.graph = build()
134
+ print("GAIA 50% Agent ready")
135
+
136
+ def __call__(self, question: str) -> str:
137
+ out = self.graph.invoke(
138
+ {
139
+ "question": question,
140
+ "searched": False,
141
+ }
142
+ )
143
+ return out.get("answer", "")
requirements.txt CHANGED
@@ -1,4 +1,8 @@
1
  gradio
2
  requests
 
 
 
3
  langchain_openai
4
- langchain_core
 
 
1
  gradio
2
  requests
3
+ langgraph
4
+ langchain
5
+ langchain-community
6
  langchain_openai
7
+ duckduckgo-search
8
+ ddgs