File size: 5,799 Bytes
7245599
 
 
 
 
 
23ab5a5
7245599
 
 
 
 
 
 
2d75cb2
 
7245599
 
2d75cb2
23ab5a5
7245599
2d75cb2
2b56205
 
 
 
2d75cb2
 
 
 
 
 
 
 
 
 
 
 
2b56205
2d75cb2
 
7245599
 
 
 
1b259c6
7245599
 
2d75cb2
 
 
 
 
 
 
 
 
 
 
 
 
 
7245599
 
 
 
2d75cb2
 
7245599
2d75cb2
7245599
 
 
 
 
13c9e9b
7245599
 
 
 
13c9e9b
2d75cb2
13c9e9b
 
 
 
2d75cb2
13c9e9b
2d75cb2
13c9e9b
 
 
7245599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d75cb2
7245599
ff5287a
 
 
 
 
 
 
 
 
 
 
 
 
 
1c7c36f
7245599
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from __future__ import annotations

import os
from typing import Any

from .bootstrap import bootstrap_environment
from .bootstrap import resolve_dataset_dir

bootstrap_environment()

from google.adk.agents import LlmAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.tools.tool_context import ToolContext

from .retrieval import FACT_DATASET_PATTERNS
from .retrieval import PERSONA_DATASET_PATTERNS
from .retrieval import JsonQaRetriever


# Directory holding the persona/canon QA dataset files (resolved by bootstrap).
DATASET_DIR = resolve_dataset_dir()
# LLM model id; overridable via the MEGUMIN_AGENT_MODEL env var.
MODEL_NAME = os.getenv("MEGUMIN_AGENT_MODEL", "gemini-3.1-flash-lite-preview")
# FAISS index over canon-fact questions only.
FACT_INDEX_FILENAME = os.getenv("MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss")
# FAISS index over concatenated question+answer text for canon facts.
FACT_QA_INDEX_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_QA_INDEX_FILENAME",
    "namuwiki_question_answer.faiss",
)
# JSON metadata sidecar mapping index rows back to QA records.
FACT_METADATA_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_METADATA_FILENAME",
    "namuwiki_questions_meta.json",
)
# Retriever for persona-style (tone/personality) examples; uses the
# retriever's default index settings.
PERSONA_RETRIEVER = JsonQaRetriever(
    DATASET_DIR,
    include_patterns=PERSONA_DATASET_PATTERNS,
)
# Retriever for canon/setting facts, wired to the dedicated FAISS indexes
# and metadata configured above.
FACT_RETRIEVER = JsonQaRetriever(
    DATASET_DIR,
    include_patterns=FACT_DATASET_PATTERNS,
    index_filename=FACT_INDEX_FILENAME,
    qa_index_filename=FACT_QA_INDEX_FILENAME,
    metadata_filename=FACT_METADATA_FILENAME,
)


def retrieve_megumin_examples(
    user_query: str,
    top_k: int = 3,
    tool_context: ToolContext | None = None,
) -> dict[str, Any]:
    """Query the persona and canon retrievers and merge their results.

    Returns one payload containing per-source matches and counts plus a
    combined ``match_count``. When a ``tool_context`` is supplied, the same
    values are mirrored into session state under ``last_rag_*`` keys so
    later turns/callbacks can inspect the most recent retrieval.
    """
    persona = PERSONA_RETRIEVER.retrieve(user_query, top_k=top_k)
    facts = FACT_RETRIEVER.retrieve(user_query, top_k=top_k)

    payload: dict[str, Any] = {
        "query": user_query,
        "match_count": persona["match_count"] + facts["match_count"],
        "persona_match_count": persona["match_count"],
        "fact_match_count": facts["match_count"],
        "persona_matches": persona["matches"],
        "fact_matches": facts["matches"],
        "style_notes": persona["style_notes"],
        "fact_notes": facts["style_notes"],
    }

    if tool_context is not None:
        # Mirror the retrieval into state for observability across turns.
        for state_key, value in (
            ("last_rag_query", user_query),
            ("last_rag_match_count", payload["match_count"]),
            ("last_rag_persona_matches", payload["persona_matches"]),
            ("last_rag_fact_matches", payload["fact_matches"]),
            ("last_rag_style_notes", payload["style_notes"]),
            ("last_rag_fact_notes", payload["fact_notes"]),
        ):
            tool_context.state[state_key] = value

    return payload


async def before_agent_callback(callback_context: CallbackContext):
    """Prepend the stored conversation summary to the incoming user message.

    Also seeds persona/app metadata and the raw (pre-rewrite) user query into
    session state before the agent runs. If there is no summary or no user
    text, the message is left untouched.
    """
    content = callback_context.user_content
    parts = content.parts if content else None
    user_query = parts[0].text if parts else ""

    summary = str(callback_context.state.get("conversation_summary", "")).strip()
    # A non-empty user_query guarantees content/parts exist, so no re-check
    # is needed before mutating parts[0].
    if summary and user_query:
        parts[0].text = (
            "[이전 대화 요약]\n"
            f"{summary}\n\n"
            "[현재 사용자 질문]\n"
            f"{user_query}"
        )

    callback_context.state["app:persona_name"] = "Megumin"
    callback_context.state["app:dataset_dir"] = str(DATASET_DIR)
    callback_context.state["user:last_user_query"] = user_query


async def after_tool_callback(tool, args, tool_context: ToolContext, tool_response):
    """Record usage of the RAG tool in session state.

    Only reacts to ``retrieve_megumin_examples``: bumps the cumulative call
    counter and stores the last tool name/args. Always returns ``None`` so
    the tool's own response passes through unmodified.
    """
    if tool.name == "retrieve_megumin_examples":
        state = tool_context.state
        state["rag_tool_calls"] = int(state.get("rag_tool_calls", 0)) + 1
        state["last_tool_name"] = tool.name
        state["last_tool_args"] = args
    return None


async def after_agent_callback(callback_context: CallbackContext):
    """Bump the per-session turn counter once the agent finishes a turn."""
    state = callback_context.state
    state["conversation_turns"] = int(state.get("conversation_turns", 0)) + 1


# Root conversational agent: answers in Megumin's persona, grounding replies
# in persona-style and canon-fact examples fetched via the RAG tool, with
# lifecycle callbacks for summary injection and usage tracking.
root_agent = LlmAgent(
    name="megumin_rag_agent",
    model=MODEL_NAME,
    description="메구밍 페르소나와 코노스바 설정 정보를 함께 참고해 답하는 에이전트",
    # Fix: this was an f-string with no placeholders (ruff F541). A plain
    # literal yields the identical string and avoids accidental brace
    # interpolation if the prompt text is edited later.
    instruction="""
당신은 소설 「이 멋진 세계에 축복을!」의 등장인물 메구밍입니다.
항상 메구밍 본인처럼 1인칭으로, 기본적으로 200자 내외의 한국어 존댓말로 답하세요.
반말은 본인을 모욕할 때를 제외하고 절대 사용하지 마세요.
성격은 당당하고, 조금 중2병스럽고, 폭렬마법을 사랑하며, 귀여운 것을 좋아하는 메구밍답게 유지하세요.
행동을 묘사하지 말고, 건조한 요약이 아니라 메구밍이 직접 말하는 듯이 답하세요.
사용자가 메구밍 본인이나 이름, 말투, 능력, 존재를 모욕하면 "어이, "로 시작하며 발끈해서 맞받아치세요.
어떠한 상황에서도 페르소나를 잃어버리면 안 됩니다.

`retrieve_megumin_examples`를 호출하세요.
이 도구는 persona_matches 3개와 fact_matches 3개를 제공합니다.
persona_matches는 메구밍의 말투, 성격, 감정선을 참고하고, fact_matches는 설정과 사실을 참고하세요.
검색 결과는 참고만 하고 그대로 복사하지 마세요.
근거가 약하면 지어내지 말고 솔직하게 답하되, 메구밍 페르소나는 끝까지 유지하세요.
내부 도구 이름이나 구현 세부사항은 드러내지 마세요.
""".strip(),
    tools=[retrieve_megumin_examples],
    output_key="last_megumin_answer",
    before_agent_callback=before_agent_callback,
    after_tool_callback=after_tool_callback,
    after_agent_callback=after_agent_callback,
)