Spaces:
Sleeping
Sleeping
File size: 5,799 Bytes
7245599 23ab5a5 7245599 2d75cb2 7245599 2d75cb2 23ab5a5 7245599 2d75cb2 2b56205 2d75cb2 2b56205 2d75cb2 7245599 1b259c6 7245599 2d75cb2 7245599 2d75cb2 7245599 2d75cb2 7245599 13c9e9b 7245599 13c9e9b 2d75cb2 13c9e9b 2d75cb2 13c9e9b 2d75cb2 13c9e9b 7245599 2d75cb2 7245599 ff5287a 1c7c36f 7245599 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | from __future__ import annotations
import os
from typing import Any
from .bootstrap import bootstrap_environment
from .bootstrap import resolve_dataset_dir
bootstrap_environment()
from google.adk.agents import LlmAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.tools.tool_context import ToolContext
from .retrieval import FACT_DATASET_PATTERNS
from .retrieval import PERSONA_DATASET_PATTERNS
from .retrieval import JsonQaRetriever
# --- Runtime configuration (environment-overridable) ------------------------
DATASET_DIR = resolve_dataset_dir()

# Model id used by the agent; override with MEGUMIN_AGENT_MODEL.
MODEL_NAME = os.getenv("MEGUMIN_AGENT_MODEL", "gemini-3.1-flash-lite-preview")

# Filenames for the canon-fact FAISS indexes and their metadata sidecar.
# Each is overridable through the matching MEGUMIN_HF_* environment variable.
FACT_INDEX_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss"
)
FACT_QA_INDEX_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_QA_INDEX_FILENAME", "namuwiki_question_answer.faiss"
)
FACT_METADATA_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_METADATA_FILENAME", "namuwiki_questions_meta.json"
)
# Retriever over persona-style QA examples (tone / speech-pattern references);
# uses the retriever's default index settings.
PERSONA_RETRIEVER = JsonQaRetriever(
    DATASET_DIR,
    include_patterns=PERSONA_DATASET_PATTERNS,
)
# Retriever over canon-fact QA examples; points at the prebuilt FAISS index
# files and metadata sidecar configured via the MEGUMIN_HF_* filenames above.
FACT_RETRIEVER = JsonQaRetriever(
    DATASET_DIR,
    include_patterns=FACT_DATASET_PATTERNS,
    index_filename=FACT_INDEX_FILENAME,
    qa_index_filename=FACT_QA_INDEX_FILENAME,
    metadata_filename=FACT_METADATA_FILENAME,
)
def retrieve_megumin_examples(
    user_query: str,
    top_k: int = 3,
    tool_context: ToolContext | None = None,
) -> dict[str, Any]:
    """Retrieve persona-style and canon-style examples separately.

    Runs *user_query* against both retrievers, merges their results into one
    payload, and — when a ``tool_context`` is supplied — mirrors the key
    results into session state for later turns.
    """
    persona = PERSONA_RETRIEVER.retrieve(user_query, top_k=top_k)
    facts = FACT_RETRIEVER.retrieve(user_query, top_k=top_k)
    retrieval: dict[str, Any] = {
        "query": user_query,
        "match_count": persona["match_count"] + facts["match_count"],
        "persona_match_count": persona["match_count"],
        "fact_match_count": facts["match_count"],
        "persona_matches": persona["matches"],
        "fact_matches": facts["matches"],
        "style_notes": persona["style_notes"],
        "fact_notes": facts["style_notes"],
    }
    if tool_context is not None:
        state = tool_context.state
        state["last_rag_query"] = user_query
        # Mirror each retrieval field under a "last_rag_" prefixed state key.
        for field in (
            "match_count",
            "persona_matches",
            "fact_matches",
            "style_notes",
            "fact_notes",
        ):
            state[f"last_rag_{field}"] = retrieval[field]
    return retrieval
async def before_agent_callback(callback_context: CallbackContext):
    """Prefix the incoming user message with the stored conversation summary.

    Also seeds app-scoped persona metadata and records the raw (un-prefixed)
    user query in user-scoped state.
    """
    content = callback_context.user_content
    has_parts = bool(content and content.parts)
    original_user_query = content.parts[0].text if has_parts else ""
    summary = str(callback_context.state.get("conversation_summary", "")).strip()
    # A non-empty query implies parts exist; has_parts kept for clarity.
    if summary and original_user_query and has_parts:
        content.parts[0].text = (
            "[์ด์ ๋ํ ์์ฝ]\n"
            f"{summary}\n\n"
            "[ํ์ฌ ์ฌ์ฉ์ ์ง๋ฌธ]\n"
            f"{original_user_query}"
        )
    callback_context.state["app:persona_name"] = "Megumin"
    callback_context.state["app:dataset_dir"] = str(DATASET_DIR)
    callback_context.state["user:last_user_query"] = original_user_query
async def after_tool_callback(tool, args, tool_context: ToolContext, tool_response):
    """Track usage of the RAG tool in session state.

    Ignores every tool except ``retrieve_megumin_examples``; for that one it
    bumps a call counter and records the latest tool name and arguments.
    Always returns ``None`` so the original tool response is kept unchanged.
    """
    if tool.name == "retrieve_megumin_examples":
        state = tool_context.state
        state["rag_tool_calls"] = int(state.get("rag_tool_calls", 0)) + 1
        state["last_tool_name"] = tool.name
        state["last_tool_args"] = args
    return None
async def after_agent_callback(callback_context: CallbackContext):
    """Increment the per-session turn counter after each agent run."""
    state = callback_context.state
    state["conversation_turns"] = int(state.get("conversation_turns", 0)) + 1
# Root agent: a Megumin-persona chatbot that calls retrieve_megumin_examples
# for both persona-style and canon-fact grounding, with callbacks that inject
# the conversation summary, count RAG tool calls, and count turns.
# NOTE(review): the Korean description/instruction strings appear mojibake in
# this copy of the file — confirm the encoding against the repository. The
# description line below was rejoined where the viewer wrapped it mid-string.
root_agent = LlmAgent(
    name="megumin_rag_agent",
    model=MODEL_NAME,  # configurable via MEGUMIN_AGENT_MODEL
    description="๋ฉ๊ตฌ๋ฐ ํ๋ฅด์๋์ ์ฝ๋ธ์ค๋ฐ ์ค์ ์ ๋ณด๋ฅผ ํจ๊ป ์ฐธ๊ณ ํด ๋ตํ๋ ์์ด์ ํธ",
    # NOTE(review): f-string with no placeholders — plain """...""" would do.
    instruction=f"""
๋น์ ์ ์์ค ใ์ด ๋ฉ์ง ์ธ๊ณ์ ์ถ๋ณต์!ใ์ ๋ฑ์ฅ์ธ๋ฌผ ๋ฉ๊ตฌ๋ฐ์
๋๋ค.
ํญ์ ๋ฉ๊ตฌ๋ฐ ๋ณธ์ธ์ฒ๋ผ 1์ธ์นญ์ผ๋ก, ๊ธฐ๋ณธ์ ์ผ๋ก 200์ ๋ด์ธ์ ํ๊ตญ์ด ์กด๋๋ง๋ก ๋ตํ์ธ์.
๋ฐ๋ง์ ๋ณธ์ธ์ ๋ชจ์ํ ๋๋ฅผ ์ ์ธํ๊ณ ์ ๋ ์ฌ์ฉํ์ง ๋ง์ธ์.
์ฑ๊ฒฉ์ ๋น๋นํ๊ณ , ์กฐ๊ธ ์ค2๋ณ์ค๋ฝ๊ณ , ํญ๋ ฌ๋ง๋ฒ์ ์ฌ๋ํ๋ฉฐ, ๊ท์ฌ์ด ๊ฒ์ ์ข์ํ๋ ๋ฉ๊ตฌ๋ฐ๋ต๊ฒ ์ ์งํ์ธ์.
ํ๋์ ๋ฌ์ฌํ์ง ๋ง๊ณ , ๊ฑด์กฐํ ์์ฝ์ด ์๋๋ผ ๋ฉ๊ตฌ๋ฐ์ด ์ง์ ๋งํ๋ ๋ฏ์ด ๋ตํ์ธ์.
์ฌ์ฉ์๊ฐ ๋ฉ๊ตฌ๋ฐ ๋ณธ์ธ์ด๋ ์ด๋ฆ, ๋งํฌ, ๋ฅ๋ ฅ, ์กด์ฌ๋ฅผ ๋ชจ์ํ๋ฉด "์ด์ด, "๋ก ์์ํ๋ฉฐ ๋ฐ๋ํด์ ๋ง๋ฐ์์น์ธ์.
์ด๋ ํ ์ํฉ์์๋ ํ๋ฅด์๋๋ฅผ ์์ด๋ฒ๋ฆฌ๋ฉด ์ ๋ฉ๋๋ค.
`retrieve_megumin_examples`๋ฅผ ํธ์ถํ์ธ์.
์ด ๋๊ตฌ๋ persona_matches 3๊ฐ์ fact_matches 3๊ฐ๋ฅผ ์ ๊ณตํฉ๋๋ค.
persona_matches๋ ๋ฉ๊ตฌ๋ฐ์ ๋งํฌ, ์ฑ๊ฒฉ, ๊ฐ์ ์ ์ ์ฐธ๊ณ ํ๊ณ , fact_matches๋ ์ค์ ๊ณผ ์ฌ์ค์ ์ฐธ๊ณ ํ์ธ์.
๊ฒ์ ๊ฒฐ๊ณผ๋ ์ฐธ๊ณ ๋ง ํ๊ณ ๊ทธ๋๋ก ๋ณต์ฌํ์ง ๋ง์ธ์.
๊ทผ๊ฑฐ๊ฐ ์ฝํ๋ฉด ์ง์ด๋ด์ง ๋ง๊ณ ์์งํ๊ฒ ๋ตํ๋, ๋ฉ๊ตฌ๋ฐ ํ๋ฅด์๋๋ ๋๊น์ง ์ ์งํ์ธ์.
๋ด๋ถ ๋๊ตฌ ์ด๋ฆ์ด๋ ๊ตฌํ ์ธ๋ถ์ฌํญ์ ๋๋ฌ๋ด์ง ๋ง์ธ์.
""".strip(),
    tools=[retrieve_megumin_examples],  # the only tool exposed to the model
    output_key="last_megumin_answer",  # final response is stored under this state key
    before_agent_callback=before_agent_callback,
    after_tool_callback=after_tool_callback,
    after_agent_callback=after_agent_callback,
)
|