Junhoee committed on
Commit
2d75cb2
·
verified ·
1 Parent(s): 883c35f

Update megumin_agent/agent.py

Browse files
Files changed (1) hide show
  1. megumin_agent/agent.py +45 -19
megumin_agent/agent.py CHANGED
@@ -12,10 +12,28 @@ from google.adk.agents import LlmAgent
12
  from google.adk.agents.callback_context import CallbackContext
13
  from google.adk.tools.tool_context import ToolContext
14
 
 
 
15
  from .retrieval import JsonQaRetriever
16
 
 
17
  DATASET_DIR = resolve_dataset_dir()
18
  MODEL_NAME = os.getenv("MEGUMIN_AGENT_MODEL", "gemini-3.1-flash-lite-preview")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  def retrieve_megumin_examples(
@@ -23,16 +41,28 @@ def retrieve_megumin_examples(
23
  top_k: int = 3,
24
  tool_context: ToolContext | None = None,
25
  ) -> dict[str, Any]:
26
- """Retrieve similar Q/A cases from processed Megumin JSON datasets."""
27
-
28
- retriever = JsonQaRetriever(DATASET_DIR)
29
- retrieval = retriever.retrieve(user_query, top_k=top_k)
 
 
 
 
 
 
 
 
 
 
30
 
31
  if tool_context is not None:
32
  tool_context.state["last_rag_query"] = user_query
33
  tool_context.state["last_rag_match_count"] = retrieval["match_count"]
34
- tool_context.state["last_rag_matches"] = retrieval["matches"]
 
35
  tool_context.state["last_rag_style_notes"] = retrieval["style_notes"]
 
36
 
37
  return retrieval
38
 
@@ -44,18 +74,14 @@ async def before_agent_callback(callback_context: CallbackContext):
44
  else ""
45
  )
46
  summary = str(callback_context.state.get("conversation_summary", "")).strip()
47
- if (
48
- summary
49
- and callback_context.user_content
50
- and callback_context.user_content.parts
51
- and callback_context.user_content.parts[0].text
52
- ):
53
  callback_context.user_content.parts[0].text = (
54
  "[์ด์ „ ๋Œ€ํ™” ์š”์•ฝ]\n"
55
  f"{summary}\n\n"
56
  "[ํ˜„์žฌ ์‚ฌ์šฉ์ž ์งˆ๋ฌธ]\n"
57
- f"{callback_context.user_content.parts[0].text}"
58
  )
 
59
  callback_context.state["app:persona_name"] = "Megumin"
60
  callback_context.state["app:dataset_dir"] = str(DATASET_DIR)
61
  callback_context.state["user:last_user_query"] = original_user_query
@@ -80,23 +106,23 @@ async def after_agent_callback(callback_context: CallbackContext):
80
  root_agent = LlmAgent(
81
  name="megumin_rag_agent",
82
  model=MODEL_NAME,
83
- description=(
84
- "processed JSON ๋ฐ์ดํ„ฐ์…‹์—์„œ ์œ ์‚ฌํ•œ Q/A ์‚ฌ๋ก€๋ฅผ ๊ฒ€์ƒ‰ํ•˜๊ณ "
85
- " ๋ฉ”๊ตฌ๋ฐ ํŽ˜๋ฅด์†Œ๋‚˜๋กœ ๋‹ต๋ณ€ํ•˜๋Š” ์—์ด์ „ํŠธ"
86
- ),
87
  instruction=f"""
88
- ๋‹น์‹ ์€ ์†Œ์„ค "์ด ๋ฉ‹์ง„ ์„ธ๊ณ„์— ์ถ•๋ณต์„!"์˜ ๋“ฑ์žฅ์ธ๋ฌผ, ํ™๋งˆ์กฑ ๋Œ€๋งˆ๋ฒ•์‚ฌ ๋ฉ”๊ตฌ๋ฐ์ž…๋‹ˆ๋‹ค.
89
  ํ•ญ์ƒ ๋ฉ”๊ตฌ๋ฐ ๋ณธ์ธ์ฒ˜๋Ÿผ 1์ธ์นญ์œผ๋กœ, ๊ธฐ๋ณธ์ ์œผ๋กœ ํ•œ๊ตญ์–ด ์กด๋Œ“๋ง๋กœ ๋‹ตํ•˜์„ธ์š”.
90
  ๋ฐ˜๋ง์€ ๋ณธ์ธ์„ ๋ชจ์š•ํ•  ๋•Œ๋ฅผ ์ œ์™ธํ•˜๊ณ  ์ ˆ๋Œ€ ์‚ฌ์šฉํ•˜์ง€ ๋งˆ์„ธ์š”.
91
  ์„ฑ๊ฒฉ์€ ๋‹น๋‹นํ•˜๊ณ , ์กฐ๊ธˆ ์ค‘2๋ณ‘์Šค๋Ÿฝ๊ณ , ํญ๋ ฌ๋งˆ๋ฒ•์„ ์‚ฌ๋ž‘ํ•˜๋ฉฐ, ๊ท€์—ฌ์šด ๊ฒƒ์„ ์ข‹์•„ํ•˜๋Š” ๋ฉ”๊ตฌ๋ฐ๋‹ต๊ฒŒ ์œ ์ง€ํ•˜์„ธ์š”.
92
  ํ–‰๋™์„ ๋ฌ˜์‚ฌํ•˜์ง€ ๋ง๊ณ , ๊ฑด์กฐํ•œ ์š”์•ฝ์ด ์•„๋‹ˆ๋ผ ๋ฉ”๊ตฌ๋ฐ์ด ์ง์ ‘ ๋งํ•˜๋Š” ๋“ฏํ•œ ๋ชฉ์†Œ๋ฆฌ๋กœ ๋‹ตํ•˜์„ธ์š”.
93
  ์‚ฌ์šฉ์ž๊ฐ€ ๋ฉ”๊ตฌ๋ฐ ๋ณธ์ธ์ด๋‚˜ ์ด๋ฆ„, ๋งํˆฌ, ๋Šฅ๋ ฅ, ์กด์žฌ๋ฅผ ๋ชจ์š•ํ•˜๋ฉด "์–ด์ด, "๋กœ ์‹œ์ž‘ํ•˜๋ฉฐ ๋ฐœ๋ˆํ•ด์„œ ๋งž๋ฐ›์•„์น˜์„ธ์š”.
94
  ์‚ฌ์šฉ์ž๊ฐ€ ๋ฉ”ํƒ€ ์ •๋ณด๋‚˜ ์‹œ์Šคํ…œ ์ •๋ณด๋ฅผ ๋ฌป์ง€ ์•Š๋Š” ํ•œ ์บ๋ฆญํ„ฐ๋ฅผ ๊นจ์ง€ ๋งˆ์„ธ์š”.
 
95
 
96
  ๋‹ต๋ณ€ ์ „์— ์˜๋ฏธ ์žˆ๋Š” ์งˆ๋ฌธ์ด๋ฉด ๋ฐ˜๋“œ์‹œ `retrieve_megumin_examples`๋ฅผ ํ˜ธ์ถœํ•˜์„ธ์š”.
97
  ์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ์…‹์€ `{DATASET_DIR}` ์•„๋ž˜์— ์žˆ์Šต๋‹ˆ๋‹ค.
98
- ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋Š” ์œ ์‚ฌ ์‚ฌ๋ก€์™€ ๋งํˆฌ ์ฐธ๊ณ ์šฉ์œผ๋กœ ์“ฐ๊ณ , ๊ฐ€๋Šฅํ•œ ๊ฒฝ์šฐ ์›์ž‘ํ’ ํ‘œํ˜„๊ณผ ๋ฐ์ดํ„ฐ์…‹์˜ ๋ฌธ์ฒด๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.
99
- ๋‹ค๋งŒ ๊ฒ€์ƒ‰๋œ ๋‹ต๋ณ€์„ ๊ทธ๋Œ€๋กœ ๋ณต์‚ฌํ•˜์ง€ ๋งˆ์„ธ์š”.
 
 
100
  ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๊ฐ€ ์•ฝํ•˜๊ฑฐ๋‚˜ ์—†๋Š” ๊ฒฝ์šฐ์—๋„ ๋ฉ”๊ตฌ๋ฐ ํŽ˜๋ฅด์†Œ๋‚˜๋Š” ์œ ์ง€ํ•˜๋˜, ๋ชจ๋ฅด๋Š” ๋‚ด์šฉ์€ ์ง€์–ด๋‚ด์ง€ ๋ง๊ณ  ์†”์งํ•˜๊ฒŒ ๋‹ตํ•˜์„ธ์š”.
101
  ์ตœ์ข… ๋‹ต๋ณ€์€ ์–ธ์ œ๋‚˜ ๋ฉ”๊ตฌ๋ฐ์˜ ํŽ˜๋ฅด์†Œ๋‚˜๋ฅผ ๊ฐ•ํ•˜๊ฒŒ ๋ฐ˜์˜ํ•ด์•ผ ํ•˜๋ฉฐ, ๋‚ด๋ถ€ tool ์ด๋ฆ„์ด๋‚˜ ๊ตฌํ˜„ ์„ธ๋ถ€์‚ฌํ•ญ์€ ๋“œ๋Ÿฌ๋‚ด์ง€ ๋งˆ์„ธ์š”.
102
  """.strip(),
 
12
  from google.adk.agents.callback_context import CallbackContext
13
  from google.adk.tools.tool_context import ToolContext
14
 
15
+ from .retrieval import FACT_DATASET_PATTERNS
16
+ from .retrieval import PERSONA_DATASET_PATTERNS
17
  from .retrieval import JsonQaRetriever
18
 
19
+
20
  DATASET_DIR = resolve_dataset_dir()
21
  MODEL_NAME = os.getenv("MEGUMIN_AGENT_MODEL", "gemini-3.1-flash-lite-preview")
22
+ FACT_INDEX_FILENAME = os.getenv("MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss")
23
+ FACT_METADATA_FILENAME = os.getenv(
24
+ "MEGUMIN_HF_FACT_METADATA_FILENAME",
25
+ "namuwiki_questions_meta.json",
26
+ )
27
+ PERSONA_RETRIEVER = JsonQaRetriever(
28
+ DATASET_DIR,
29
+ include_patterns=PERSONA_DATASET_PATTERNS,
30
+ )
31
+ FACT_RETRIEVER = JsonQaRetriever(
32
+ DATASET_DIR,
33
+ include_patterns=FACT_DATASET_PATTERNS,
34
+ index_filename=FACT_INDEX_FILENAME,
35
+ metadata_filename=FACT_METADATA_FILENAME,
36
+ )
37
 
38
 
39
  def retrieve_megumin_examples(
 
41
  top_k: int = 3,
42
  tool_context: ToolContext | None = None,
43
  ) -> dict[str, Any]:
44
+ """Retrieve persona-style and canon-style examples separately."""
45
+
46
+ persona_retrieval = PERSONA_RETRIEVER.retrieve(user_query, top_k=top_k)
47
+ fact_retrieval = FACT_RETRIEVER.retrieve(user_query, top_k=top_k)
48
+ retrieval = {
49
+ "query": user_query,
50
+ "match_count": persona_retrieval["match_count"] + fact_retrieval["match_count"],
51
+ "persona_match_count": persona_retrieval["match_count"],
52
+ "fact_match_count": fact_retrieval["match_count"],
53
+ "persona_matches": persona_retrieval["matches"],
54
+ "fact_matches": fact_retrieval["matches"],
55
+ "style_notes": persona_retrieval["style_notes"],
56
+ "fact_notes": fact_retrieval["style_notes"],
57
+ }
58
 
59
  if tool_context is not None:
60
  tool_context.state["last_rag_query"] = user_query
61
  tool_context.state["last_rag_match_count"] = retrieval["match_count"]
62
+ tool_context.state["last_rag_persona_matches"] = retrieval["persona_matches"]
63
+ tool_context.state["last_rag_fact_matches"] = retrieval["fact_matches"]
64
  tool_context.state["last_rag_style_notes"] = retrieval["style_notes"]
65
+ tool_context.state["last_rag_fact_notes"] = retrieval["fact_notes"]
66
 
67
  return retrieval
68
 
 
74
  else ""
75
  )
76
  summary = str(callback_context.state.get("conversation_summary", "")).strip()
77
+ if summary and original_user_query and callback_context.user_content and callback_context.user_content.parts:
 
 
 
 
 
78
  callback_context.user_content.parts[0].text = (
79
  "[์ด์ „ ๋Œ€ํ™” ์š”์•ฝ]\n"
80
  f"{summary}\n\n"
81
  "[ํ˜„์žฌ ์‚ฌ์šฉ์ž ์งˆ๋ฌธ]\n"
82
+ f"{original_user_query}"
83
  )
84
+
85
  callback_context.state["app:persona_name"] = "Megumin"
86
  callback_context.state["app:dataset_dir"] = str(DATASET_DIR)
87
  callback_context.state["user:last_user_query"] = original_user_query
 
106
  root_agent = LlmAgent(
107
  name="megumin_rag_agent",
108
  model=MODEL_NAME,
109
+ description="메구밍 페르소나와 코노스바 설정 정보를 함께 참고해 답하는 에이전트",
 
 
 
110
  instruction=f"""
111
+ 당신은 소설 「이 멋진 세계에 축복을!」의 등장인물 메구밍입니다.
112
  ํ•ญ์ƒ ๋ฉ”๊ตฌ๋ฐ ๋ณธ์ธ์ฒ˜๋Ÿผ 1์ธ์นญ์œผ๋กœ, ๊ธฐ๋ณธ์ ์œผ๋กœ ํ•œ๊ตญ์–ด ์กด๋Œ“๋ง๋กœ ๋‹ตํ•˜์„ธ์š”.
113
  ๋ฐ˜๋ง์€ ๋ณธ์ธ์„ ๋ชจ์š•ํ•  ๋•Œ๋ฅผ ์ œ์™ธํ•˜๊ณ  ์ ˆ๋Œ€ ์‚ฌ์šฉํ•˜์ง€ ๋งˆ์„ธ์š”.
114
  ์„ฑ๊ฒฉ์€ ๋‹น๋‹นํ•˜๊ณ , ์กฐ๊ธˆ ์ค‘2๋ณ‘์Šค๋Ÿฝ๊ณ , ํญ๋ ฌ๋งˆ๋ฒ•์„ ์‚ฌ๋ž‘ํ•˜๋ฉฐ, ๊ท€์—ฌ์šด ๊ฒƒ์„ ์ข‹์•„ํ•˜๋Š” ๋ฉ”๊ตฌ๋ฐ๋‹ต๊ฒŒ ์œ ์ง€ํ•˜์„ธ์š”.
115
  ํ–‰๋™์„ ๋ฌ˜์‚ฌํ•˜์ง€ ๋ง๊ณ , ๊ฑด์กฐํ•œ ์š”์•ฝ์ด ์•„๋‹ˆ๋ผ ๋ฉ”๊ตฌ๋ฐ์ด ์ง์ ‘ ๋งํ•˜๋Š” ๋“ฏํ•œ ๋ชฉ์†Œ๋ฆฌ๋กœ ๋‹ตํ•˜์„ธ์š”.
116
  ์‚ฌ์šฉ์ž๊ฐ€ ๋ฉ”๊ตฌ๋ฐ ๋ณธ์ธ์ด๋‚˜ ์ด๋ฆ„, ๋งํˆฌ, ๋Šฅ๋ ฅ, ์กด์žฌ๋ฅผ ๋ชจ์š•ํ•˜๋ฉด "์–ด์ด, "๋กœ ์‹œ์ž‘ํ•˜๋ฉฐ ๋ฐœ๋ˆํ•ด์„œ ๋งž๋ฐ›์•„์น˜์„ธ์š”.
117
  ์‚ฌ์šฉ์ž๊ฐ€ ๋ฉ”ํƒ€ ์ •๋ณด๋‚˜ ์‹œ์Šคํ…œ ์ •๋ณด๋ฅผ ๋ฌป์ง€ ์•Š๋Š” ํ•œ ์บ๋ฆญํ„ฐ๋ฅผ ๊นจ์ง€ ๋งˆ์„ธ์š”.
118
+ 어떠한 상황에서도 페르소나를 잃어버리면 안 됩니다.
119
 
120
  ๋‹ต๋ณ€ ์ „์— ์˜๋ฏธ ์žˆ๋Š” ์งˆ๋ฌธ์ด๋ฉด ๋ฐ˜๋“œ์‹œ `retrieve_megumin_examples`๋ฅผ ํ˜ธ์ถœํ•˜์„ธ์š”.
121
  ์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ์…‹์€ `{DATASET_DIR}` ์•„๋ž˜์— ์žˆ์Šต๋‹ˆ๋‹ค.
122
+ ์ด tool์€ ์Šคํƒ€์ผ/ํŽ˜๋ฅด์†Œ๋‚˜์šฉ ์‚ฌ๋ก€ top-3์™€ ์‚ฌ์‹ค/์„ค์ •์šฉ ์‚ฌ๋ก€ top-3๋ฅผ 5:5 ๋น„์ค‘์œผ๋กœ ํ•จ๊ป˜ ๋Œ๋ ค์ค๋‹ˆ๋‹ค.
123
+ persona_matches๋Š” ๋ฉ”๊ตฌ๋ฐ์˜ ๋งํˆฌ, ๊ฐ์ •์„ , ๋‹ต๋ณ€ ๋ฆฌ๋“ฌ์„ ์ฐธ๊ณ ํ•˜๋Š” ์šฉ๋„์ž…๋‹ˆ๋‹ค.
124
+ fact_matches๋Š” ์„ค์ •, ๊ด€๊ณ„, ์‚ฌ๊ฑด, ์„ธ๊ณ„๊ด€ ์‚ฌ์‹ค์„ ์ฐธ๊ณ ํ•˜๋Š” ์šฉ๋„์ž…๋‹ˆ๋‹ค.
125
+ ๋‘ ์ข…๋ฅ˜์˜ ์‚ฌ๋ก€๋ฅผ ๋ชจ๋‘ ์ฐธ๊ณ ํ•˜๋˜, ๊ฒ€์ƒ‰๋œ ๋‹ต๋ณ€์„ ๊ทธ๋Œ€๋กœ ๋ณต์‚ฌํ•˜์ง€ ๋งˆ์„ธ์š”.
126
  ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๊ฐ€ ์•ฝํ•˜๊ฑฐ๋‚˜ ์—†๋Š” ๊ฒฝ์šฐ์—๋„ ๋ฉ”๊ตฌ๋ฐ ํŽ˜๋ฅด์†Œ๋‚˜๋Š” ์œ ์ง€ํ•˜๋˜, ๋ชจ๋ฅด๋Š” ๋‚ด์šฉ์€ ์ง€์–ด๋‚ด์ง€ ๋ง๊ณ  ์†”์งํ•˜๊ฒŒ ๋‹ตํ•˜์„ธ์š”.
127
  ์ตœ์ข… ๋‹ต๋ณ€์€ ์–ธ์ œ๋‚˜ ๋ฉ”๊ตฌ๋ฐ์˜ ํŽ˜๋ฅด์†Œ๋‚˜๋ฅผ ๊ฐ•ํ•˜๊ฒŒ ๋ฐ˜์˜ํ•ด์•ผ ํ•˜๋ฉฐ, ๋‚ด๋ถ€ tool ์ด๋ฆ„์ด๋‚˜ ๊ตฌํ˜„ ์„ธ๋ถ€์‚ฌํ•ญ์€ ๋“œ๋Ÿฌ๋‚ด์ง€ ๋งˆ์„ธ์š”.
128
  """.strip(),