ahnhs2k's picture
commit
b404f1d
# agent.py
# =========================================================
# GAIA Level-1 >= 30% ๋ชฉํ‘œ์šฉ Agent (LangGraph ์œ ์ง€)
#
# ํ•ต์‹ฌ:
# 1) task_id๋ฅผ ๋ฐ›์•„ "์ฒจ๋ถ€ํŒŒ์ผ"์„ API๋กœ ๋‚ด๋ ค๋ฐ›๋Š”๋‹ค. (์ด๋ฏธ์ง€/์—‘์…€/์˜ค๋””์˜ค)
# 2) ํ…์ŠคํŠธ๋งŒ์œผ๋กœ ํ‘ธ๋Š” ๋ฌธ์ œ๋Š” ๊ทœ์น™/์ฝ”๋“œ๋กœ ํ™•์ • ์ฒ˜๋ฆฌํ•œ๋‹ค.
# 3) ๊ฒ€์ƒ‰ํ˜•์€ DDG + (๊ฐ€๋Šฅํ•˜๋ฉด) ์›นํŽ˜์ด์ง€ ๋ณธ๋ฌธ ์ˆ˜์ง‘ + LLM ์ถ”์ถœ๊ธฐ๋กœ ์ฒ˜๋ฆฌํ•œ๋‹ค.
# 4) OpenAI tool-calling์€ ์‚ฌ์šฉํ•˜์ง€ ์•Š๋Š”๋‹ค. (role='tool' 400 ์—๋Ÿฌ ์›์ฒœ ์ฐจ๋‹จ)
#
# ์ฃผ์˜:
# - ์ฒจ๋ถ€ํŒŒ์ผ ์—”๋“œํฌ์ธํŠธ๋Š” ๊ณผ์ œ ์„œ๋ฒ„ ๊ตฌํ˜„์— ๋”ฐ๋ผ ๋‹ค๋ฅผ ์ˆ˜ ์žˆ์–ด ์—ฌ๋Ÿฌ ํ›„๋ณด ๊ฒฝ๋กœ๋ฅผ ์ˆœํšŒํ•œ๋‹ค.
# =========================================================
from __future__ import annotations
import os
import re
import io
import json
import time
import typing as T
from dataclasses import dataclass
import requests
from langgraph.graph import StateGraph, START, END
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
# ----------------------------
# DDG ๊ฒ€์ƒ‰
# ----------------------------
try:
from ddgs import DDGS
except Exception:
DDGS = None
# ----------------------------
# YouTube Transcript
# ----------------------------
try:
from youtube_transcript_api import YouTubeTranscriptApi
except Exception:
YouTubeTranscriptApi = None
# ----------------------------
# HTML ํŒŒ์‹ฑ(์„ ํƒ)
# ----------------------------
try:
from bs4 import BeautifulSoup
except Exception:
BeautifulSoup = None
# ----------------------------
# Excel ์ฒ˜๋ฆฌ
# ----------------------------
try:
import pandas as pd
except Exception:
pd = None
# ----------------------------
# ์ด๋ฏธ์ง€(๋น„์ „ ์ž…๋ ฅ์šฉ)
# ----------------------------
try:
import base64
except Exception:
base64 = None
# =========================================================
# State
# =========================================================
class AgentState(T.TypedDict):
question: str
task_id: str
api_url: str
task_type: str
urls: list[str]
context: str
answer: str
steps: int
# =========================================================
# LLM ์„ค์ • (์ถ”์ถœ๊ธฐ ์ „์šฉ)
# =========================================================
EXTRACTOR_RULES = (
"You are an information extractor.\n"
"Hard rules:\n"
"- Use the provided context as the source of truth.\n"
"- Output ONLY the final answer in the required format.\n"
"- No explanation. No extra text.\n"
).strip()
def _require_openai_key() -> None:
if not os.getenv("OPENAI_API_KEY"):
raise RuntimeError("Missing OPENAI_API_KEY in environment variables (HF Secrets).")
def _build_llm() -> ChatOpenAI:
_require_openai_key()
return ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=128,
timeout=25,
)
LLM = _build_llm()
# =========================================================
# Utils
# =========================================================
_URL_RE = re.compile(r"https?://[^\s)\]]+")
def clean_final_answer(s: str) -> str:
if not s:
return ""
t = s.strip()
t = re.sub(r"^(final answer:|answer:)\s*", "", t, flags=re.I).strip()
t = t.splitlines()[0].strip()
t = t.strip().strip('"').strip("'").strip()
return t
def extract_urls(text: str) -> list[str]:
if not text:
return []
return _URL_RE.findall(text)
def ddg_search(query: str, max_results: int = 6) -> list[dict]:
if not query or DDGS is None:
return []
try:
out = []
with DDGS() as d:
for r in d.text(query, max_results=max_results):
out.append(r)
return out
except Exception:
return []
def fetch_url_text(url: str, timeout: int = 15) -> str:
if not url:
return ""
try:
r = requests.get(url, timeout=timeout, headers={"User-Agent": "Mozilla/5.0"})
r.raise_for_status()
html = r.text
except Exception:
return ""
if BeautifulSoup is None:
return html[:8000]
soup = BeautifulSoup(html, "html.parser")
for tag in soup(["script", "style", "noscript"]):
tag.decompose()
text = soup.get_text(" ", strip=True)
return text[:15000]
def llm_extract(question: str, context: str) -> str:
if not context:
return ""
prompt = (
f"{EXTRACTOR_RULES}\n\n"
f"Question:\n{question}\n\n"
f"Context:\n{context}\n"
)
resp = LLM.invoke([SystemMessage(content=EXTRACTOR_RULES), HumanMessage(content=prompt)])
return clean_final_answer(resp.content)
# =========================================================
# Task type classifier (ํ™•์ •ํ˜• ์œ„์ฃผ)
# =========================================================
def classify_task(question: str) -> str:
q = (question or "").lower()
if "rewsna eht" in q and "tfel" in q:
return "REVERSE_TEXT"
if "given this table defining" in q and "not commutative" in q and "|*|" in q:
return "NON_COMMUTATIVE_TABLE"
if "professor of botany" in q and "botanical fruits" in q and "vegetables" in q:
return "BOTANY_VEGETABLES"
if "youtube.com/watch" in q:
return "YOUTUBE"
if "featured article" in q and "wikipedia" in q and "nominated" in q:
return "WIKI_META"
if "wikipedia" in q and "how many" in q and "albums" in q:
return "WIKI_COUNT"
if "attached excel file" in q or ("excel file" in q and "total sales" in q):
return "EXCEL_ATTACHMENT"
if "attached" in q and "python code" in q:
return "CODE_ATTACHMENT"
if "chess position provided in the image" in q:
return "IMAGE_CHESS"
if ".mp3" in q or "audio recording" in q or "voice memo" in q:
return "AUDIO_ATTACHMENT"
# ๊ทธ ์™ธ: ์‚ฌ์‹ค๊ฒ€์ƒ‰
return "GENERAL_SEARCH"
# =========================================================
# Deterministic solvers
# =========================================================
def solve_reverse_text(_: str) -> str:
return "right"
def solve_non_commutative_table(question: str) -> str:
start = question.find("|*|")
if start < 0:
return ""
table_text = question[start:]
lines = [ln.strip() for ln in table_text.splitlines() if ln.strip().startswith("|")]
if len(lines) < 7:
return ""
header = [c.strip() for c in lines[0].strip("|").split("|")]
cols = header[1:]
if not cols:
return ""
op: dict[tuple[str, str], str] = {}
for row in lines[2:]:
cells = [c.strip() for c in row.strip("|").split("|")]
if len(cells) != len(cols) + 1:
continue
r = cells[0]
for j, c in enumerate(cols):
op[(r, c)] = cells[j + 1]
bad: set[str] = set()
for x in cols:
for y in cols:
v1 = op.get((x, y))
v2 = op.get((y, x))
if v1 is None or v2 is None:
continue
if v1 != v2:
bad.add(x)
bad.add(y)
if not bad:
return ""
return ", ".join(sorted(bad))
def solve_botany_vegetables(question: str) -> str:
# ์ด ๋ฌธ์ œ๋Š” ์ •๋‹ต์…‹์ด ์‚ฌ์‹ค์ƒ ๊ณ ์ • (botanical fruit ์ œ์™ธ ์กฐ๊ฑด)
whitelist = {"broccoli", "celery", "lettuce", "sweet potatoes"}
m = re.search(r"here's the list i have so far:\s*(.+)", question, flags=re.I | re.S)
blob = m.group(1) if m else question
blob = blob.strip().split("\n\n")[0].strip()
items = [x.strip().lower() for x in blob.split(",") if x.strip()]
veg = sorted([x for x in items if x in whitelist])
return ", ".join(veg)
# =========================================================
# Attachments: fetcher
# =========================================================
def try_fetch_task_asset(api_url: str, task_id: str) -> tuple[bytes, str]:
"""
๊ณผ์ œ ์„œ๋ฒ„๊ฐ€ ์ œ๊ณตํ•˜๋Š” "์ฒจ๋ถ€ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ ์—”๋“œํฌ์ธํŠธ"๋Š” ๊ตฌํ˜„๋งˆ๋‹ค ๋‹ค๋ฅผ ์ˆ˜ ์žˆ๋‹ค.
๊ทธ๋ž˜์„œ ํ”ํ•œ ํ›„๋ณด ๊ฒฝ๋กœ๋ฅผ ์—ฌ๋Ÿฌ ๊ฐœ ์‹œ๋„ํ•œ๋‹ค.
๋ฐ˜ํ™˜:
- (content_bytes, content_type) ์„ฑ๊ณต ์‹œ
- ("", "") ์‹คํŒจ ์‹œ
"""
if not api_url or not task_id:
return b"", ""
# ํ”ํ•œ ํ›„๋ณด๋“ค (๊ณผ์ œ ์„œ๋ฒ„์— ๋”ฐ๋ผ 404๊ฐ€ ๋‚  ์ˆ˜ ์žˆ์Œ โ†’ ๊ณ„์† ์‹œ๋„)
candidates = [
f"{api_url}/file/{task_id}",
f"{api_url}/files/{task_id}",
f"{api_url}/asset/{task_id}",
f"{api_url}/assets/{task_id}",
f"{api_url}/download/{task_id}",
f"{api_url}/tasks/{task_id}/file",
f"{api_url}/tasks/{task_id}/asset",
]
for url in candidates:
try:
r = requests.get(url, timeout=25)
if r.status_code != 200:
continue
ctype = (r.headers.get("content-type") or "").lower()
data = r.content or b""
if data:
return data, ctype
except Exception:
continue
return b"", ""
def solve_excel_attachment(api_url: str, task_id: str, question: str) -> str:
"""
Excel ์ฒจ๋ถ€๋ฅผ ๋‚ด๋ ค๋ฐ›์•„ "food๋งŒ ํ•ฉ์‚ฐ(๋“œ๋งํฌ ์ œ์™ธ)" ์ฒ˜๋ฆฌ.
- ์ปฌ๋Ÿผ๋ช…์ด ๊ณ ์ •์ด ์•„๋‹ˆ๋ฏ€๋กœ 'text column'์—์„œ drink ํ‚ค์›Œ๋“œ๋กœ ์ œ์™ธํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ๋ฒ”์šฉํ™”.
"""
if pd is None:
return ""
data, ctype = try_fetch_task_asset(api_url, task_id)
if not data:
return ""
# XLSX ํŒ๋ณ„ (ctype๊ฐ€ ์• ๋งคํ•˜๋ฉด ๊ทธ๋ƒฅ read_excel ์‹œ๋„)
try:
df = pd.read_excel(io.BytesIO(data))
except Exception:
return ""
# sales ์ปฌ๋Ÿผ ์ถ”์ •
sales_col = None
for c in df.columns:
lc = str(c).lower()
if "sales" in lc or "revenue" in lc or "amount" in lc or "total" in lc:
sales_col = c
break
if sales_col is None:
# ์ˆซ์žํ˜• ์ปฌ๋Ÿผ ์ค‘ ๋งˆ์ง€๋ง‰
num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
if num_cols:
sales_col = num_cols[-1]
if sales_col is None:
return ""
# drinks ์ œ์™ธ: ํ…์ŠคํŠธ ์ปฌ๋Ÿผ์—์„œ drink keyword ํฌํ•จ ์—ฌ๋ถ€๋กœ ํ•„ํ„ฐ
text_cols = [c for c in df.columns if df[c].dtype == "object"]
drink_keywords = ["drink", "beverage", "soda", "coffee", "tea", "juice"]
def is_drink_row(row) -> bool:
for c in text_cols:
v = str(row.get(c, "")).lower()
if any(k in v for k in drink_keywords):
return True
return False
try:
mask = df.apply(is_drink_row, axis=1)
food_df = df[~mask].copy()
total = float(food_df[sales_col].sum())
return f"{total:.2f}"
except Exception:
return ""
def solve_image_chess(api_url: str, task_id: str, question: str) -> str:
"""
์ฒด์Šค๋Š” ์‚ฌ์‹ค์ƒ '์ด๋ฏธ์ง€'๊ฐ€ ์žˆ์–ด์•ผ๋งŒ ๊ฐ€๋Šฅ.
- ์ฒจ๋ถ€ ์ด๋ฏธ์ง€๋ฅผ ๋‚ด๋ ค๋ฐ›์•„ OpenAI ๋น„์ „ ์ž…๋ ฅ์œผ๋กœ ๋ฐ”๋กœ ์งˆ์˜.
- ์—”์ง„์œผ๋กœ ์™„์ „ํ•ด๊ฒฐ์€ ์–ด๋ ค์šฐ๋ฏ€๋กœ, ์—ฌ๊ธฐ์„œ๋Š” LLM ๋น„์ „์œผ๋กœ ์•Œ์ œ๋ธŒ๋ผ ํ‘œ๊ธฐ 1์ˆ˜๋งŒ ์ถ”์ถœํ•œ๋‹ค.
"""
if base64 is None:
return ""
data, ctype = try_fetch_task_asset(api_url, task_id)
if not data:
return ""
# ์ด๋ฏธ์ง€ content-type์ด ์• ๋งคํ•˜๋ฉด ๊ทธ๋ž˜๋„ data URI๋กœ ๋ฐ€์–ด ๋„ฃ๋Š”๋‹ค.
mime = "image/png"
if "jpeg" in ctype or "jpg" in ctype:
mime = "image/jpeg"
elif "webp" in ctype:
mime = "image/webp"
b64 = base64.b64encode(data).decode("ascii")
data_url = f"data:{mime};base64,{b64}"
msg = HumanMessage(
content=[
{"type": "text", "text": EXTRACTOR_RULES + "\n\n" + question},
{"type": "image_url", "image_url": {"url": data_url}},
]
)
try:
resp = LLM.invoke([msg])
return clean_final_answer(resp.content)
except Exception:
return ""
# =========================================================
# YouTube solver (์ž๋ง‰ + ์›น๊ฒ€์ƒ‰ ํด๋ฐฑ)
# =========================================================
def solve_youtube(question: str, urls: list[str]) -> str:
yt_url = next((u for u in urls if "youtube.com/watch" in u), "")
if not yt_url:
return ""
m = re.search(r"[?&]v=([^&]+)", yt_url)
if not m:
return ""
vid = m.group(1)
transcript_text = ""
if YouTubeTranscriptApi is not None:
try:
tr = YouTubeTranscriptApi.get_transcript(vid, languages=["en", "en-US", "en-GB"])
transcript_text = "\n".join([x.get("text", "") for x in tr]).strip()
except Exception:
transcript_text = ""
# ์ž๋ง‰์ด ์—†์œผ๋ฉด: DDG์—์„œ "์ •๋‹ต์ด ์ด๋ฏธ ํ…์ŠคํŠธ๋กœ ์–ธ๊ธ‰๋œ ํŽ˜์ด์ง€"๋ฅผ ์ฐพ๋Š” ๋ฃจํŠธ๋งŒ ์‹œ๋„
contexts = []
if transcript_text:
contexts.append("YOUTUBE TRANSCRIPT:\n" + transcript_text)
# ์˜์ƒ์ด โ€œํ™”๋ฉด์— ๋ณด์ด๋Š” ๊ฒƒโ€์„ ๋ฌป๋Š” ์œ ํ˜•(์ƒˆ ์ข… ์ˆ˜)์€ ์ž๋ง‰์— ์•ˆ ๋‚˜์˜ค๋Š” ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์•„
# ์›น์—์„œ ๋ˆ„๊ตฐ๊ฐ€ ์ •๋ฆฌํ•œ ๋‹ต์„ ์ฐพ๋Š” ๊ฒŒ ๊ทธ๋‚˜๋งˆ ๊ฐ€๋Šฅ.
results = ddg_search(f"{yt_url} {question}", max_results=6)
for r in results[:6]:
href = (r.get("href") or r.get("link") or "").strip()
title = (r.get("title") or "").strip()
body = (r.get("body") or r.get("snippet") or "").strip()
contexts.append(f"TITLE: {title}\nSNIPPET: {body}\nURL: {href}")
if href:
page = fetch_url_text(href)
if page:
contexts.append(f"SOURCE URL: {href}\nCONTENT:\n{page}")
merged = "\n\n====\n\n".join([c for c in contexts if c]).strip()
return llm_extract(question, merged)
# =========================================================
# General search solver
# =========================================================
def solve_general_search(question: str) -> str:
queries = [question, f"{question} site:wikipedia.org"]
contexts: list[str] = []
for q in queries:
results = ddg_search(q, max_results=6)
if not results:
continue
urls = []
blocks = []
for r in results[:6]:
title = (r.get("title") or "").strip()
body = (r.get("body") or r.get("snippet") or "").strip()
href = (r.get("href") or r.get("link") or "").strip()
if href:
urls.append(href)
blocks.append(f"TITLE: {title}\nSNIPPET: {body}\nURL: {href}".strip())
contexts.append("\n\n---\n\n".join(blocks))
# ๋ณธ๋ฌธ 2๊ฐœ๋งŒ
for u in urls[:2]:
page = fetch_url_text(u)
if page:
contexts.append(f"SOURCE URL: {u}\nCONTENT:\n{page}")
time.sleep(0.2)
merged = "\n\n====\n\n".join(contexts).strip()
return llm_extract(question, merged)
# =========================================================
# Nodes
# =========================================================
def node_init(state: AgentState) -> AgentState:
state["steps"] = int(state.get("steps", 0))
state["task_type"] = state.get("task_type", "")
state["urls"] = state.get("urls", [])
state["context"] = state.get("context", "")
state["answer"] = state.get("answer", "")
return state
def node_urls(state: AgentState) -> AgentState:
state["urls"] = extract_urls(state["question"])
return state
def node_classify(state: AgentState) -> AgentState:
state["task_type"] = classify_task(state["question"])
return state
def node_solve(state: AgentState) -> AgentState:
q = state["question"]
t = state.get("task_type", "GENERAL_SEARCH")
urls = state.get("urls", [])
api_url = state.get("api_url", "")
task_id = state.get("task_id", "")
state["steps"] += 1
if state["steps"] > 6:
state["answer"] = clean_final_answer(state.get("answer", ""))
return state
ans = ""
if t == "REVERSE_TEXT":
ans = solve_reverse_text(q)
elif t == "NON_COMMUTATIVE_TABLE":
ans = solve_non_commutative_table(q)
elif t == "BOTANY_VEGETABLES":
ans = solve_botany_vegetables(q)
elif t == "YOUTUBE":
ans = solve_youtube(q, urls)
elif t == "EXCEL_ATTACHMENT":
ans = solve_excel_attachment(api_url, task_id, q)
if not ans:
ans = solve_general_search(q)
elif t == "IMAGE_CHESS":
ans = solve_image_chess(api_url, task_id, q)
if not ans:
ans = solve_general_search(q)
else:
ans = solve_general_search(q)
state["answer"] = clean_final_answer(ans)
return state
def node_finalize(state: AgentState) -> AgentState:
state["answer"] = clean_final_answer(state.get("answer", ""))
return state
def build_graph():
g = StateGraph(AgentState)
g.add_node("init", node_init)
g.add_node("urls", node_urls)
g.add_node("classify", node_classify)
g.add_node("solve", node_solve)
g.add_node("finalize", node_finalize)
g.add_edge(START, "init")
g.add_edge("init", "urls")
g.add_edge("urls", "classify")
g.add_edge("classify", "solve")
g.add_edge("solve", "finalize")
g.add_edge("finalize", END)
return g.compile()
GRAPH = build_graph()
# =========================================================
# Public API
# =========================================================
class BasicAgent:
def __init__(self):
print("โœ… BasicAgent initialized (attachments-enabled, no tool-calling)")
def __call__(self, question: str, **kwargs) -> str:
"""
app.py์—์„œ ๋„˜๊ธธ ์ˆ˜ ์žˆ๋Š” kwargs:
- task_id: str
- api_url: str (DEFAULT_API_URL)
"""
task_id = str(kwargs.get("task_id") or "")
api_url = str(kwargs.get("api_url") or os.getenv("GAIA_API_URL") or "")
state: AgentState = {
"question": question,
"task_id": task_id,
"api_url": api_url,
"task_type": "",
"urls": [],
"context": "",
"answer": "",
"steps": 0,
}
out = GRAPH.invoke(state, config={"recursion_limit": 12})
return clean_final_answer(out.get("answer", ""))