Spaces:
Running
Running
Update megumin_agent/agent.py
Browse files- megumin_agent/agent.py +45 -19
megumin_agent/agent.py
CHANGED
|
@@ -12,10 +12,28 @@ from google.adk.agents import LlmAgent
|
|
| 12 |
from google.adk.agents.callback_context import CallbackContext
|
| 13 |
from google.adk.tools.tool_context import ToolContext
|
| 14 |
|
|
|
|
|
|
|
| 15 |
from .retrieval import JsonQaRetriever
|
| 16 |
|
|
|
|
| 17 |
DATASET_DIR = resolve_dataset_dir()
|
| 18 |
MODEL_NAME = os.getenv("MEGUMIN_AGENT_MODEL", "gemini-3.1-flash-lite-preview")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def retrieve_megumin_examples(
|
|
@@ -23,16 +41,28 @@ def retrieve_megumin_examples(
|
|
| 23 |
top_k: int = 3,
|
| 24 |
tool_context: ToolContext | None = None,
|
| 25 |
) -> dict[str, Any]:
|
| 26 |
-
"""Retrieve
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
if tool_context is not None:
|
| 32 |
tool_context.state["last_rag_query"] = user_query
|
| 33 |
tool_context.state["last_rag_match_count"] = retrieval["match_count"]
|
| 34 |
-
tool_context.state["
|
|
|
|
| 35 |
tool_context.state["last_rag_style_notes"] = retrieval["style_notes"]
|
|
|
|
| 36 |
|
| 37 |
return retrieval
|
| 38 |
|
|
@@ -44,18 +74,14 @@ async def before_agent_callback(callback_context: CallbackContext):
|
|
| 44 |
else ""
|
| 45 |
)
|
| 46 |
summary = str(callback_context.state.get("conversation_summary", "")).strip()
|
| 47 |
-
if
|
| 48 |
-
summary
|
| 49 |
-
and callback_context.user_content
|
| 50 |
-
and callback_context.user_content.parts
|
| 51 |
-
and callback_context.user_content.parts[0].text
|
| 52 |
-
):
|
| 53 |
callback_context.user_content.parts[0].text = (
|
| 54 |
"[์ด์ ๋ํ ์์ฝ]\n"
|
| 55 |
f"{summary}\n\n"
|
| 56 |
"[ํ์ฌ ์ฌ์ฉ์ ์ง๋ฌธ]\n"
|
| 57 |
-
f"{
|
| 58 |
)
|
|
|
|
| 59 |
callback_context.state["app:persona_name"] = "Megumin"
|
| 60 |
callback_context.state["app:dataset_dir"] = str(DATASET_DIR)
|
| 61 |
callback_context.state["user:last_user_query"] = original_user_query
|
|
@@ -80,23 +106,23 @@ async def after_agent_callback(callback_context: CallbackContext):
|
|
| 80 |
root_agent = LlmAgent(
|
| 81 |
name="megumin_rag_agent",
|
| 82 |
model=MODEL_NAME,
|
| 83 |
-
description=
|
| 84 |
-
"processed JSON ๋ฐ์ดํฐ์
์์ ์ ์ฌํ Q/A ์ฌ๋ก๋ฅผ ๊ฒ์ํ๊ณ "
|
| 85 |
-
" ๋ฉ๊ตฌ๋ฐ ํ๋ฅด์๋๋ก ๋ต๋ณํ๋ ์์ด์ ํธ"
|
| 86 |
-
),
|
| 87 |
instruction=f"""
|
| 88 |
-
๋น์ ์ ์์ค
|
| 89 |
ํญ์ ๋ฉ๊ตฌ๋ฐ ๋ณธ์ธ์ฒ๋ผ 1์ธ์นญ์ผ๋ก, ๊ธฐ๋ณธ์ ์ผ๋ก ํ๊ตญ์ด ์กด๋๋ง๋ก ๋ตํ์ธ์.
|
| 90 |
๋ฐ๋ง์ ๋ณธ์ธ์ ๋ชจ์ํ ๋๋ฅผ ์ ์ธํ๊ณ ์ ๋ ์ฌ์ฉํ์ง ๋ง์ธ์.
|
| 91 |
์ฑ๊ฒฉ์ ๋น๋นํ๊ณ , ์กฐ๊ธ ์ค2๋ณ์ค๋ฝ๊ณ , ํญ๋ ฌ๋ง๋ฒ์ ์ฌ๋ํ๋ฉฐ, ๊ท์ฌ์ด ๊ฒ์ ์ข์ํ๋ ๋ฉ๊ตฌ๋ฐ๋ต๊ฒ ์ ์งํ์ธ์.
|
| 92 |
ํ๋์ ๋ฌ์ฌํ์ง ๋ง๊ณ , ๊ฑด์กฐํ ์์ฝ์ด ์๋๋ผ ๋ฉ๊ตฌ๋ฐ์ด ์ง์ ๋งํ๋ ๋ฏํ ๋ชฉ์๋ฆฌ๋ก ๋ตํ์ธ์.
|
| 93 |
์ฌ์ฉ์๊ฐ ๋ฉ๊ตฌ๋ฐ ๋ณธ์ธ์ด๋ ์ด๋ฆ, ๋งํฌ, ๋ฅ๋ ฅ, ์กด์ฌ๋ฅผ ๋ชจ์ํ๋ฉด "์ด์ด, "๋ก ์์ํ๋ฉฐ ๋ฐ๋ํด์ ๋ง๋ฐ์์น์ธ์.
|
| 94 |
์ฌ์ฉ์๊ฐ ๋ฉํ ์ ๋ณด๋ ์์คํ
์ ๋ณด๋ฅผ ๋ฌป์ง ์๋ ํ ์บ๋ฆญํฐ๋ฅผ ๊นจ์ง ๋ง์ธ์.
|
|
|
|
| 95 |
|
| 96 |
๋ต๋ณ ์ ์ ์๋ฏธ ์๋ ์ง๋ฌธ์ด๋ฉด ๋ฐ๋์ `retrieve_megumin_examples`๋ฅผ ํธ์ถํ์ธ์.
|
| 97 |
์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ์
์ `{DATASET_DIR}` ์๋์ ์์ต๋๋ค.
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
๊ฒ์ ๊ฒฐ๊ณผ๊ฐ ์ฝํ๊ฑฐ๋ ์๋ ๊ฒฝ์ฐ์๋ ๋ฉ๊ตฌ๋ฐ ํ๋ฅด์๋๋ ์ ์งํ๋, ๋ชจ๋ฅด๋ ๋ด์ฉ์ ์ง์ด๋ด์ง ๋ง๊ณ ์์งํ๊ฒ ๋ตํ์ธ์.
|
| 101 |
์ต์ข
๋ต๋ณ์ ์ธ์ ๋ ๋ฉ๊ตฌ๋ฐ์ ํ๋ฅด์๋๋ฅผ ๊ฐํ๊ฒ ๋ฐ์ํด์ผ ํ๋ฉฐ, ๋ด๋ถ tool ์ด๋ฆ์ด๋ ๊ตฌํ ์ธ๋ถ์ฌํญ์ ๋๋ฌ๋ด์ง ๋ง์ธ์.
|
| 102 |
""".strip(),
|
|
|
|
| 12 |
from google.adk.agents.callback_context import CallbackContext
|
| 13 |
from google.adk.tools.tool_context import ToolContext
|
| 14 |
|
| 15 |
+
from .retrieval import FACT_DATASET_PATTERNS
|
| 16 |
+
from .retrieval import PERSONA_DATASET_PATTERNS
|
| 17 |
from .retrieval import JsonQaRetriever
|
| 18 |
|
| 19 |
+
|
| 20 |
DATASET_DIR = resolve_dataset_dir()
|
| 21 |
MODEL_NAME = os.getenv("MEGUMIN_AGENT_MODEL", "gemini-3.1-flash-lite-preview")
|
| 22 |
+
FACT_INDEX_FILENAME = os.getenv("MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss")
|
| 23 |
+
FACT_METADATA_FILENAME = os.getenv(
|
| 24 |
+
"MEGUMIN_HF_FACT_METADATA_FILENAME",
|
| 25 |
+
"namuwiki_questions_meta.json",
|
| 26 |
+
)
|
| 27 |
+
PERSONA_RETRIEVER = JsonQaRetriever(
|
| 28 |
+
DATASET_DIR,
|
| 29 |
+
include_patterns=PERSONA_DATASET_PATTERNS,
|
| 30 |
+
)
|
| 31 |
+
FACT_RETRIEVER = JsonQaRetriever(
|
| 32 |
+
DATASET_DIR,
|
| 33 |
+
include_patterns=FACT_DATASET_PATTERNS,
|
| 34 |
+
index_filename=FACT_INDEX_FILENAME,
|
| 35 |
+
metadata_filename=FACT_METADATA_FILENAME,
|
| 36 |
+
)
|
| 37 |
|
| 38 |
|
| 39 |
def retrieve_megumin_examples(
|
|
|
|
| 41 |
top_k: int = 3,
|
| 42 |
tool_context: ToolContext | None = None,
|
| 43 |
) -> dict[str, Any]:
|
| 44 |
+
"""Retrieve persona-style and canon-style examples separately."""
|
| 45 |
+
|
| 46 |
+
persona_retrieval = PERSONA_RETRIEVER.retrieve(user_query, top_k=top_k)
|
| 47 |
+
fact_retrieval = FACT_RETRIEVER.retrieve(user_query, top_k=top_k)
|
| 48 |
+
retrieval = {
|
| 49 |
+
"query": user_query,
|
| 50 |
+
"match_count": persona_retrieval["match_count"] + fact_retrieval["match_count"],
|
| 51 |
+
"persona_match_count": persona_retrieval["match_count"],
|
| 52 |
+
"fact_match_count": fact_retrieval["match_count"],
|
| 53 |
+
"persona_matches": persona_retrieval["matches"],
|
| 54 |
+
"fact_matches": fact_retrieval["matches"],
|
| 55 |
+
"style_notes": persona_retrieval["style_notes"],
|
| 56 |
+
"fact_notes": fact_retrieval["style_notes"],
|
| 57 |
+
}
|
| 58 |
|
| 59 |
if tool_context is not None:
|
| 60 |
tool_context.state["last_rag_query"] = user_query
|
| 61 |
tool_context.state["last_rag_match_count"] = retrieval["match_count"]
|
| 62 |
+
tool_context.state["last_rag_persona_matches"] = retrieval["persona_matches"]
|
| 63 |
+
tool_context.state["last_rag_fact_matches"] = retrieval["fact_matches"]
|
| 64 |
tool_context.state["last_rag_style_notes"] = retrieval["style_notes"]
|
| 65 |
+
tool_context.state["last_rag_fact_notes"] = retrieval["fact_notes"]
|
| 66 |
|
| 67 |
return retrieval
|
| 68 |
|
|
|
|
| 74 |
else ""
|
| 75 |
)
|
| 76 |
summary = str(callback_context.state.get("conversation_summary", "")).strip()
|
| 77 |
+
if summary and original_user_query and callback_context.user_content and callback_context.user_content.parts:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
callback_context.user_content.parts[0].text = (
|
| 79 |
"[์ด์ ๋ํ ์์ฝ]\n"
|
| 80 |
f"{summary}\n\n"
|
| 81 |
"[ํ์ฌ ์ฌ์ฉ์ ์ง๋ฌธ]\n"
|
| 82 |
+
f"{original_user_query}"
|
| 83 |
)
|
| 84 |
+
|
| 85 |
callback_context.state["app:persona_name"] = "Megumin"
|
| 86 |
callback_context.state["app:dataset_dir"] = str(DATASET_DIR)
|
| 87 |
callback_context.state["user:last_user_query"] = original_user_query
|
|
|
|
| 106 |
root_agent = LlmAgent(
|
| 107 |
name="megumin_rag_agent",
|
| 108 |
model=MODEL_NAME,
|
| 109 |
+
description="๋ฉ๊ตฌ๋ฐ ํ๋ฅด์๋์ ์ฝ๋
ธ์ค๋ฐ ์ค์ ์ ๋ณด๋ฅผ ํจ๊ป ์ฐธ๊ณ ํด ๋ตํ๋ ์์ด์ ํธ",
|
|
|
|
|
|
|
|
|
|
| 110 |
instruction=f"""
|
| 111 |
+
๋น์ ์ ์์ค ใ์ด ๋ฉ์ง ์ธ๊ณ์ ์ถ๋ณต์!ใ์ ๋ฑ์ฅ์ธ๋ฌผ ๋ฉ๊ตฌ๋ฐ์
๋๋ค.
|
| 112 |
ํญ์ ๋ฉ๊ตฌ๋ฐ ๋ณธ์ธ์ฒ๋ผ 1์ธ์นญ์ผ๋ก, ๊ธฐ๋ณธ์ ์ผ๋ก ํ๊ตญ์ด ์กด๋๋ง๋ก ๋ตํ์ธ์.
|
| 113 |
๋ฐ๋ง์ ๋ณธ์ธ์ ๋ชจ์ํ ๋๋ฅผ ์ ์ธํ๊ณ ์ ๋ ์ฌ์ฉํ์ง ๋ง์ธ์.
|
| 114 |
์ฑ๊ฒฉ์ ๋น๋นํ๊ณ , ์กฐ๊ธ ์ค2๋ณ์ค๋ฝ๊ณ , ํญ๋ ฌ๋ง๋ฒ์ ์ฌ๋ํ๋ฉฐ, ๊ท์ฌ์ด ๊ฒ์ ์ข์ํ๋ ๋ฉ๊ตฌ๋ฐ๋ต๊ฒ ์ ์งํ์ธ์.
|
| 115 |
ํ๋์ ๋ฌ์ฌํ์ง ๋ง๊ณ , ๊ฑด์กฐํ ์์ฝ์ด ์๋๋ผ ๋ฉ๊ตฌ๋ฐ์ด ์ง์ ๋งํ๋ ๋ฏํ ๋ชฉ์๋ฆฌ๋ก ๋ตํ์ธ์.
|
| 116 |
์ฌ์ฉ์๊ฐ ๋ฉ๊ตฌ๋ฐ ๋ณธ์ธ์ด๋ ์ด๋ฆ, ๋งํฌ, ๋ฅ๋ ฅ, ์กด์ฌ๋ฅผ ๋ชจ์ํ๋ฉด "์ด์ด, "๋ก ์์ํ๋ฉฐ ๋ฐ๋ํด์ ๋ง๋ฐ์์น์ธ์.
|
| 117 |
์ฌ์ฉ์๊ฐ ๋ฉํ ์ ๋ณด๋ ์์คํ
์ ๋ณด๋ฅผ ๋ฌป์ง ์๋ ํ ์บ๋ฆญํฐ๋ฅผ ๊นจ์ง ๋ง์ธ์.
|
| 118 |
+
์ด๋ ํ ์ํฉ์์๋ ํ๋ฅด์๋๋ฅผ ์์ด๋ฒ๋ฆฌ๋ฉด ์ ๋ฉ๋๋ค.
|
| 119 |
|
| 120 |
๋ต๋ณ ์ ์ ์๋ฏธ ์๋ ์ง๋ฌธ์ด๋ฉด ๋ฐ๋์ `retrieve_megumin_examples`๋ฅผ ํธ์ถํ์ธ์.
|
| 121 |
์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ์
์ `{DATASET_DIR}` ์๋์ ์์ต๋๋ค.
|
| 122 |
+
์ด tool์ ์คํ์ผ/ํ๋ฅด์๋์ฉ ์ฌ๋ก top-3์ ์ฌ์ค/์ค์ ์ฉ ์ฌ๋ก top-3๋ฅผ 5:5 ๋น์ค์ผ๋ก ํจ๊ป ๋๋ ค์ค๋๋ค.
|
| 123 |
+
persona_matches๋ ๋ฉ๊ตฌ๋ฐ์ ๋งํฌ, ๊ฐ์ ์ , ๋ต๋ณ ๋ฆฌ๋ฌ์ ์ฐธ๊ณ ํ๋ ์ฉ๋์
๋๋ค.
|
| 124 |
+
fact_matches๋ ์ค์ , ๊ด๊ณ, ์ฌ๊ฑด, ์ธ๊ณ๊ด ์ฌ์ค์ ์ฐธ๊ณ ํ๋ ์ฉ๋์
๋๋ค.
|
| 125 |
+
๋ ์ข
๋ฅ์ ์ฌ๋ก๋ฅผ ๋ชจ๋ ์ฐธ๊ณ ํ๋, ๊ฒ์๋ ๋ต๋ณ์ ๊ทธ๋๋ก ๋ณต์ฌํ์ง ๋ง์ธ์.
|
| 126 |
๊ฒ์ ๊ฒฐ๊ณผ๊ฐ ์ฝํ๊ฑฐ๋ ์๋ ๊ฒฝ์ฐ์๋ ๋ฉ๊ตฌ๋ฐ ํ๋ฅด์๋๋ ์ ์งํ๋, ๋ชจ๋ฅด๋ ๋ด์ฉ์ ์ง์ด๋ด์ง ๋ง๊ณ ์์งํ๊ฒ ๋ตํ์ธ์.
|
| 127 |
์ต์ข
๋ต๋ณ์ ์ธ์ ๋ ๋ฉ๊ตฌ๋ฐ์ ํ๋ฅด์๋๋ฅผ ๊ฐํ๊ฒ ๋ฐ์ํด์ผ ํ๋ฉฐ, ๋ด๋ถ tool ์ด๋ฆ์ด๋ ๊ตฌํ ์ธ๋ถ์ฌํญ์ ๋๋ฌ๋ด์ง ๋ง์ธ์.
|
| 128 |
""".strip(),
|