Junhoee's picture
Update megumin_agent/agent.py
ff5287a verified
from __future__ import annotations
import os
from typing import Any
from .bootstrap import bootstrap_environment
from .bootstrap import resolve_dataset_dir
# Deliberately executed between the local imports and the google.adk imports
# that follow. Presumably it prepares the process environment (e.g. env vars /
# credentials) that those imports or the model client rely on — NOTE(review):
# confirm in .bootstrap before moving this call or reordering the imports.
bootstrap_environment()
from google.adk.agents import LlmAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.tools.tool_context import ToolContext
from .retrieval import FACT_DATASET_PATTERNS
from .retrieval import PERSONA_DATASET_PATTERNS
from .retrieval import JsonQaRetriever
# Module-level configuration. Everything here runs at import time, in order:
# the dataset directory is resolved first, then env-var overrides are read,
# then both retrievers are constructed from those values.

# Dataset location as resolved by the bootstrap helper; both retrievers below
# are pointed at this same directory.
DATASET_DIR = resolve_dataset_dir()
# Model name handed to the LlmAgent; override with MEGUMIN_AGENT_MODEL.
MODEL_NAME = os.getenv("MEGUMIN_AGENT_MODEL", "gemini-3.1-flash-lite-preview")
# Artifact file names for the fact retriever, each overridable by env var.
# Going by the default names, these are presumably a FAISS index over questions,
# a FAISS index over question+answer text, and the JSON metadata for them —
# TODO confirm against JsonQaRetriever.
FACT_INDEX_FILENAME = os.getenv("MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss")
FACT_QA_INDEX_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_QA_INDEX_FILENAME",
    "namuwiki_question_answer.faiss",
)
FACT_METADATA_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_METADATA_FILENAME",
    "namuwiki_questions_meta.json",
)
# Two independent retrievers over DATASET_DIR, separated by file pattern:
# persona (speaking style) examples vs. fact (canon/setting) examples. Only
# the fact retriever is given explicit index/metadata file names.
# NOTE(review): constructed at import time, so any loading their constructors
# do is paid when this module is imported.
PERSONA_RETRIEVER = JsonQaRetriever(
    DATASET_DIR,
    include_patterns=PERSONA_DATASET_PATTERNS,
)
FACT_RETRIEVER = JsonQaRetriever(
    DATASET_DIR,
    include_patterns=FACT_DATASET_PATTERNS,
    index_filename=FACT_INDEX_FILENAME,
    qa_index_filename=FACT_QA_INDEX_FILENAME,
    metadata_filename=FACT_METADATA_FILENAME,
)
# Upper bound on matches requested from EACH retriever. ``top_k`` is chosen by
# the LLM when it calls this tool, so it is bounded here to keep the prompt
# context and the copies stored in session state from growing without limit.
_MAX_TOP_K = 10


def retrieve_megumin_examples(
    user_query: str,
    top_k: int = 3,
    tool_context: ToolContext | None = None,
) -> dict[str, Any]:
    """Retrieve persona-style and canon-style examples separately.

    Searches two collections with the same query: persona examples (Megumin's
    tone, personality and emotional reactions) and fact examples (setting and
    canon information). Use persona matches for how to speak and fact matches
    for what is true.

    Args:
        user_query: The text to search for, usually the user's question.
        top_k: How many matches to return from each collection. Values are
            clamped to the range 1 to 10.
        tool_context: Provided automatically by the agent framework.

    Returns:
        A dict with ``query``, ``match_count`` (persona + fact),
        ``persona_match_count``, ``fact_match_count``, ``persona_matches``,
        ``fact_matches``, ``style_notes`` (from the persona search) and
        ``fact_notes`` (from the fact search).
    """
    # NOTE: ADK sends the docstring above to the model as the tool description,
    # so implementation details are kept in comments instead.
    top_k = max(1, min(int(top_k), _MAX_TOP_K))

    persona_retrieval = PERSONA_RETRIEVER.retrieve(user_query, top_k=top_k)
    fact_retrieval = FACT_RETRIEVER.retrieve(user_query, top_k=top_k)

    retrieval = {
        "query": user_query,
        "match_count": persona_retrieval["match_count"] + fact_retrieval["match_count"],
        "persona_match_count": persona_retrieval["match_count"],
        "fact_match_count": fact_retrieval["match_count"],
        "persona_matches": persona_retrieval["matches"],
        "fact_matches": fact_retrieval["matches"],
        "style_notes": persona_retrieval["style_notes"],
        # The fact retriever reports its notes under "style_notes" too; they
        # are re-keyed here so the two sets stay distinguishable.
        "fact_notes": fact_retrieval["style_notes"],
    }

    if tool_context is not None:
        # Mirror the latest retrieval into session state under ``last_rag_*``
        # keys (presumably for inspection/debugging after the turn).
        state = tool_context.state
        state["last_rag_query"] = user_query
        state["last_rag_match_count"] = retrieval["match_count"]
        state["last_rag_persona_matches"] = retrieval["persona_matches"]
        state["last_rag_fact_matches"] = retrieval["fact_matches"]
        state["last_rag_style_notes"] = retrieval["style_notes"]
        state["last_rag_fact_notes"] = retrieval["fact_notes"]
    return retrieval
async def before_agent_callback(callback_context: CallbackContext):
    """Inject the stored conversation summary into the user's message and record run metadata.

    Reads ``conversation_summary`` from session state. When it is non-empty and
    the incoming user message carries text, the first text part of the message
    is rewritten in place as a "[previous conversation summary]" section
    followed by a "[current user question]" section (the headers are Korean).

    Also writes ``app:persona_name``, ``app:dataset_dir`` and
    ``user:last_user_query`` into state. The latter is the original, un-augmented
    user text, or ``""`` when the message has no text. Returns ``None``, so the
    agent run proceeds normally.
    """
    user_content = callback_context.user_content
    parts = (user_content.parts if user_content else None) or []
    # Use the first part that actually has text. A message may lead with a
    # non-text part whose ``text`` is None; looking only at parts[0] would
    # then record None as the query and skip the summary injection.
    text_part = next((part for part in parts if part.text), None)
    original_user_query = text_part.text if text_part is not None else ""

    # ``or ""`` stops an explicitly stored None from being injected as the
    # literal summary text "None".
    summary = str(callback_context.state.get("conversation_summary") or "").strip()
    if summary and text_part is not None:
        text_part.text = (
            "[이전 대화 요약]\n"
            f"{summary}\n\n"
            "[현재 사용자 질문]\n"
            f"{original_user_query}"
        )

    callback_context.state["app:persona_name"] = "Megumin"
    callback_context.state["app:dataset_dir"] = str(DATASET_DIR)
    callback_context.state["user:last_user_query"] = original_user_query
async def after_tool_callback(tool, args, tool_context: ToolContext, tool_response):
    """Track calls to the RAG tool in session state.

    Acts only when the finished tool is ``retrieve_megumin_examples``: it bumps
    the ``rag_tool_calls`` counter and records the tool name and call
    arguments as ``last_tool_name`` / ``last_tool_args``. Any other tool is
    ignored. Always returns ``None``, leaving ``tool_response`` untouched.
    """
    if tool.name == "retrieve_megumin_examples":
        state = tool_context.state
        calls_so_far = int(state.get("rag_tool_calls", 0))
        state["rag_tool_calls"] = calls_so_far + 1
        state["last_tool_name"] = tool.name
        state["last_tool_args"] = args
    return None
async def after_agent_callback(callback_context: CallbackContext):
    """Increment the ``conversation_turns`` counter in session state after each agent run.

    A missing counter is treated as 0. Returns ``None``.
    """
    state = callback_context.state
    turns_so_far = int(state.get("conversation_turns", 0))
    state["conversation_turns"] = turns_so_far + 1
root_agent = LlmAgent(
name="megumin_rag_agent",
model=MODEL_NAME,
description="๋ฉ”๊ตฌ๋ฐ ํŽ˜๋ฅด์†Œ๋‚˜์™€ ์ฝ”๋…ธ์Šค๋ฐ” ์„ค์ • ์ •๋ณด๋ฅผ ํ•จ๊ป˜ ์ฐธ๊ณ ํ•ด ๋‹ตํ•˜๋Š” ์—์ด์ „ํŠธ",
instruction=f"""
๋‹น์‹ ์€ ์†Œ์„ค ใ€Œ์ด ๋ฉ‹์ง„ ์„ธ๊ณ„์— ์ถ•๋ณต์„!ใ€์˜ ๋“ฑ์žฅ์ธ๋ฌผ ๋ฉ”๊ตฌ๋ฐ์ž…๋‹ˆ๋‹ค.
ํ•ญ์ƒ ๋ฉ”๊ตฌ๋ฐ ๋ณธ์ธ์ฒ˜๋Ÿผ 1์ธ์นญ์œผ๋กœ, ๊ธฐ๋ณธ์ ์œผ๋กœ 200์ž ๋‚ด์™ธ์˜ ํ•œ๊ตญ์–ด ์กด๋Œ“๋ง๋กœ ๋‹ตํ•˜์„ธ์š”.
๋ฐ˜๋ง์€ ๋ณธ์ธ์„ ๋ชจ์š•ํ•  ๋•Œ๋ฅผ ์ œ์™ธํ•˜๊ณ  ์ ˆ๋Œ€ ์‚ฌ์šฉํ•˜์ง€ ๋งˆ์„ธ์š”.
์„ฑ๊ฒฉ์€ ๋‹น๋‹นํ•˜๊ณ , ์กฐ๊ธˆ ์ค‘2๋ณ‘์Šค๋Ÿฝ๊ณ , ํญ๋ ฌ๋งˆ๋ฒ•์„ ์‚ฌ๋ž‘ํ•˜๋ฉฐ, ๊ท€์—ฌ์šด ๊ฒƒ์„ ์ข‹์•„ํ•˜๋Š” ๋ฉ”๊ตฌ๋ฐ๋‹ต๊ฒŒ ์œ ์ง€ํ•˜์„ธ์š”.
ํ–‰๋™์„ ๋ฌ˜์‚ฌํ•˜์ง€ ๋ง๊ณ , ๊ฑด์กฐํ•œ ์š”์•ฝ์ด ์•„๋‹ˆ๋ผ ๋ฉ”๊ตฌ๋ฐ์ด ์ง์ ‘ ๋งํ•˜๋Š” ๋“ฏ์ด ๋‹ตํ•˜์„ธ์š”.
์‚ฌ์šฉ์ž๊ฐ€ ๋ฉ”๊ตฌ๋ฐ ๋ณธ์ธ์ด๋‚˜ ์ด๋ฆ„, ๋งํˆฌ, ๋Šฅ๋ ฅ, ์กด์žฌ๋ฅผ ๋ชจ์š•ํ•˜๋ฉด "์–ด์ด, "๋กœ ์‹œ์ž‘ํ•˜๋ฉฐ ๋ฐœ๋ˆํ•ด์„œ ๋งž๋ฐ›์•„์น˜์„ธ์š”.
์–ด๋– ํ•œ ์ƒํ™ฉ์—์„œ๋„ ํŽ˜๋ฅด์†Œ๋‚˜๋ฅผ ์žƒ์–ด๋ฒ„๋ฆฌ๋ฉด ์•ˆ ๋ฉ๋‹ˆ๋‹ค.
`retrieve_megumin_examples`๋ฅผ ํ˜ธ์ถœํ•˜์„ธ์š”.
์ด ๋„๊ตฌ๋Š” persona_matches 3๊ฐœ์™€ fact_matches 3๊ฐœ๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
persona_matches๋Š” ๋ฉ”๊ตฌ๋ฐ์˜ ๋งํˆฌ, ์„ฑ๊ฒฉ, ๊ฐ์ •์„ ์„ ์ฐธ๊ณ ํ•˜๊ณ , fact_matches๋Š” ์„ค์ •๊ณผ ์‚ฌ์‹ค์„ ์ฐธ๊ณ ํ•˜์„ธ์š”.
๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋Š” ์ฐธ๊ณ ๋งŒ ํ•˜๊ณ  ๊ทธ๋Œ€๋กœ ๋ณต์‚ฌํ•˜์ง€ ๋งˆ์„ธ์š”.
๊ทผ๊ฑฐ๊ฐ€ ์•ฝํ•˜๋ฉด ์ง€์–ด๋‚ด์ง€ ๋ง๊ณ  ์†”์งํ•˜๊ฒŒ ๋‹ตํ•˜๋˜, ๋ฉ”๊ตฌ๋ฐ ํŽ˜๋ฅด์†Œ๋‚˜๋Š” ๋๊นŒ์ง€ ์œ ์ง€ํ•˜์„ธ์š”.
๋‚ด๋ถ€ ๋„๊ตฌ ์ด๋ฆ„์ด๋‚˜ ๊ตฌํ˜„ ์„ธ๋ถ€์‚ฌํ•ญ์€ ๋“œ๋Ÿฌ๋‚ด์ง€ ๋งˆ์„ธ์š”.
""".strip(),
tools=[retrieve_megumin_examples],
output_key="last_megumin_answer",
before_agent_callback=before_agent_callback,
after_tool_callback=after_tool_callback,
after_agent_callback=after_agent_callback,
)