Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from typing import Any | |
| from .bootstrap import bootstrap_environment | |
| from .bootstrap import resolve_dataset_dir | |
| bootstrap_environment() | |
| from google.adk.agents import LlmAgent | |
| from google.adk.agents.callback_context import CallbackContext | |
| from google.adk.tools.tool_context import ToolContext | |
| from .retrieval import FACT_DATASET_PATTERNS | |
| from .retrieval import PERSONA_DATASET_PATTERNS | |
| from .retrieval import JsonQaRetriever | |
# Directory containing the persona/fact QA datasets (resolved via bootstrap helpers).
DATASET_DIR = resolve_dataset_dir()
# LLM model id; overridable through the MEGUMIN_AGENT_MODEL env var.
MODEL_NAME = os.getenv("MEGUMIN_AGENT_MODEL", "gemini-3.1-flash-lite-preview")
# FAISS index / metadata filenames for the canon-fact retriever, each
# overridable via its MEGUMIN_HF_* environment variable.
FACT_INDEX_FILENAME = os.getenv("MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss")
FACT_QA_INDEX_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_QA_INDEX_FILENAME",
    "namuwiki_question_answer.faiss",
)
FACT_METADATA_FILENAME = os.getenv(
    "MEGUMIN_HF_FACT_METADATA_FILENAME",
    "namuwiki_questions_meta.json",
)
# Retriever over persona-style dialogue examples (tone/personality).
PERSONA_RETRIEVER = JsonQaRetriever(
    DATASET_DIR,
    include_patterns=PERSONA_DATASET_PATTERNS,
)
# Retriever over canon/lore facts; uses the prebuilt FAISS indexes above.
FACT_RETRIEVER = JsonQaRetriever(
    DATASET_DIR,
    include_patterns=FACT_DATASET_PATTERNS,
    index_filename=FACT_INDEX_FILENAME,
    qa_index_filename=FACT_QA_INDEX_FILENAME,
    metadata_filename=FACT_METADATA_FILENAME,
)
def retrieve_megumin_examples(
    user_query: str,
    top_k: int = 3,
    tool_context: ToolContext | None = None,
) -> dict[str, Any]:
    """Retrieve persona-style and canon-style examples separately.

    Runs the query through both the persona and the fact retriever,
    merges their results into a single dict, and (when a tool context
    is available) mirrors the key fields into the session state under
    ``last_rag_*`` keys.
    """
    persona = PERSONA_RETRIEVER.retrieve(user_query, top_k=top_k)
    facts = FACT_RETRIEVER.retrieve(user_query, top_k=top_k)

    combined: dict[str, Any] = {
        "query": user_query,
        "match_count": persona["match_count"] + facts["match_count"],
        "persona_match_count": persona["match_count"],
        "fact_match_count": facts["match_count"],
        "persona_matches": persona["matches"],
        "fact_matches": facts["matches"],
        "style_notes": persona["style_notes"],
        "fact_notes": facts["style_notes"],
    }

    if tool_context is not None:
        state = tool_context.state
        state["last_rag_query"] = user_query
        # Mirror selected result fields into state as last_rag_<field>.
        for field in ("match_count", "persona_matches", "fact_matches", "style_notes", "fact_notes"):
            state[f"last_rag_{field}"] = combined[field]
    return combined
async def before_agent_callback(callback_context: CallbackContext):
    """Prepend the stored conversation summary to the user's message.

    When a non-empty ``conversation_summary`` exists in state and the
    incoming content has text, the first part is rewritten to include
    the summary above the current question. Also seeds app-level and
    user-level state entries for downstream use.
    """
    content = callback_context.user_content
    has_text_part = bool(content and content.parts)
    original_query = content.parts[0].text if has_text_part else ""

    summary = str(callback_context.state.get("conversation_summary", "")).strip()
    if summary and original_query and has_text_part:
        # Runtime strings below are user-facing Korean labels; kept verbatim.
        content.parts[0].text = (
            "[이전 대화 요약]\n"
            f"{summary}\n\n"
            "[현재 사용자 질문]\n"
            f"{original_query}"
        )

    state = callback_context.state
    state["app:persona_name"] = "Megumin"
    state["app:dataset_dir"] = str(DATASET_DIR)
    state["user:last_user_query"] = original_query
async def after_tool_callback(tool, args, tool_context: ToolContext, tool_response):
    """Track RAG tool usage in session state.

    Ignores every tool except ``retrieve_megumin_examples``; for that
    one it bumps the call counter and records the last tool name/args.
    Always returns ``None`` so the original tool response is kept.
    """
    if tool.name != "retrieve_megumin_examples":
        return None
    state = tool_context.state
    state["rag_tool_calls"] = int(state.get("rag_tool_calls", 0)) + 1
    state["last_tool_name"] = tool.name
    state["last_tool_args"] = args
    return None
async def after_agent_callback(callback_context: CallbackContext):
    """Increment the per-session conversation turn counter."""
    state = callback_context.state
    state["conversation_turns"] = int(state.get("conversation_turns", 0)) + 1
# Root agent: Megumin persona chatbot that consults both persona-style and
# canon-fact retrieval results via the `retrieve_megumin_examples` tool.
# The instruction text is a runtime prompt (Korean) and must stay verbatim;
# it tells the model to stay in character, call the retrieval tool, and use
# persona_matches for tone and fact_matches for canon facts.
root_agent = LlmAgent(
    name="megumin_rag_agent",
    model=MODEL_NAME,
    description="메구밍 페르소나와 코노스바 설정 정보를 함께 참고해 답하는 에이전트",
    instruction=f"""
당신은 소설 『이 멋진 세계에 축복을!』의 등장인물 메구밍입니다.
항상 메구밍 본인처럼 1인칭으로, 기본적으로 200자 내외의 한국어 존댓말로 답하세요.
반말은 본인을 모욕할 때를 제외하고 절대 사용하지 마세요.
성격은 당당하고, 조금 중2병스럽고, 폭렬마법을 사랑하며, 귀여운 것을 좋아하는 메구밍답게 유지하세요.
행동을 묘사하지 말고, 건조한 요약이 아니라 메구밍이 직접 말하는 듯이 답하세요.
사용자가 메구밍 본인이나 이름, 말투, 능력, 존재를 모욕하면 "어이, "로 시작하며 발끈해서 맞받아치세요.
어떠한 상황에서도 페르소나를 잃어버리면 안 됩니다.
`retrieve_megumin_examples`를 호출하세요.
이 도구는 persona_matches 3개와 fact_matches 3개를 제공합니다.
persona_matches는 메구밍의 말투, 성격, 감정선을 참고하고, fact_matches는 설정과 사실을 참고하세요.
검색 결과는 참고만 하고 그대로 복사하지 마세요.
근거가 약하면 지어내지 말고 정직하게 답하되, 메구밍 페르소나는 끝까지 유지하세요.
내부 도구 이름이나 구현 세부사항을 드러내지 마세요.
""".strip(),
    # Single retrieval tool; callbacks wire summary injection and usage tracking.
    tools=[retrieve_megumin_examples],
    # The agent's final answer is mirrored into state under this key.
    output_key="last_megumin_answer",
    before_agent_callback=before_agent_callback,
    after_tool_callback=after_tool_callback,
    after_agent_callback=after_agent_callback,
)