from __future__ import annotations

import uuid
from dataclasses import dataclass
from typing import Any
from typing import AsyncIterator
from typing import Iterable

from .bootstrap import bootstrap_environment

# NOTE(review): this call sits ahead of the google.adk / google.genai imports
# in the original file; presumably it prepares the environment those packages
# read at import time, so the ordering is kept as-is -- confirm before moving.
bootstrap_environment()

from google.adk.agents.run_config import RunConfig
from google.adk.agents.run_config import StreamingMode
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types

from .agent import FACT_RETRIEVER
from .agent import PERSONA_RETRIEVER
from .agent import root_agent

# Application name used for the runner and for every session lookup.
APP_NAME = "megumin_rag_app"
# History trimming keeps ``MAX_TURNS_IN_CONTEXT * 2`` events per session.
MAX_TURNS_IN_CONTEXT = 6
# Upper bound on the length of the rendered conversation summary.
SUMMARY_MAX_CHARS = 800
# How many user / assistant entries the rolling summary retains.
SUMMARY_USER_LIMIT = 3
SUMMARY_ASSISTANT_LIMIT = 2
# Default maximum length of a single summary entry.
SUMMARY_ITEM_CHARS = 42


@dataclass
class ChatServices:
    """Pair an ADK ``Runner`` with an ``InMemorySessionService``."""

    runner: Runner
    session_service: InMemorySessionService


def _event_texts(events: Iterable) -> list[str]:
    """Render events as ``"<author>: <text>"`` lines.

    Events without content, without parts, or without any text-bearing part
    are skipped. Every author other than ``"user"`` is labelled
    ``"assistant"``. Text parts of one event are joined with a single space.
    """
    lines: list[str] = []
    for event in events:
        content = getattr(event, "content", None)
        if not content or not content.parts:
            continue
        text_parts = [
            text
            for part in content.parts
            if (text := getattr(part, "text", None))
        ]
        if not text_parts:
            continue
        author = "user" if event.author == "user" else "assistant"
        lines.append(f"{author}: {' '.join(text_parts).strip()}")
    return lines


def _compact_summary_item(text: str, limit: int = SUMMARY_ITEM_CHARS) -> str:
    """Collapse whitespace in *text* and cap it at *limit* characters.

    Runs of whitespace become a single space. Text longer than *limit* is cut
    and suffixed with ``"..."`` so the result is at most *limit* characters.
    When *limit* is too small to hold any text plus the ellipsis (below 3),
    the text is hard-truncated instead; a non-positive *limit* yields ``""``.
    """
    compact = " ".join(str(text or "").split())
    if len(compact) <= limit:
        return compact
    if limit < 3:
        # ``limit - 3`` would be a negative slice bound here and return almost
        # the whole string, so truncate without an ellipsis instead.
        return compact[: max(limit, 0)]
    return compact[: limit - 3].rstrip() + "..."
def _parse_summary_map(value: Any) -> dict[str, list[str]]:
    """Normalise a stored summary map into two lists of strings.

    Returns a dict with exactly the keys ``"user_topics"`` and
    ``"assistant_points"``. Anything that is not a dict yields two empty
    lists. Within a dict, a missing or ``None`` entry becomes an empty list,
    a bare string is treated as a single item, and items that are blank after
    ``str()`` conversion are dropped.
    """
    parsed: dict[str, list[str]] = {"user_topics": [], "assistant_points": []}
    if not isinstance(value, dict):
        return parsed
    for key in parsed:
        raw_items = value.get(key) or []
        if isinstance(raw_items, str):
            # Iterating a string would split it into single characters.
            raw_items = [raw_items]
        parsed[key] = [str(item) for item in raw_items if str(item).strip()]
    return parsed


def _merge_unique_tail(previous: list[str], additions: list[str], limit: int) -> list[str]:
    """Concatenate two lists, drop empties and duplicates, keep the last *limit*.

    A duplicate keeps the position of its first occurrence. A non-positive
    *limit* returns an empty list.
    """
    if limit <= 0:
        # ``merged[-0:]`` is the whole list, so zero needs an explicit guard.
        return []
    merged = list(dict.fromkeys(item for item in (*previous, *additions) if item))
    return merged[-limit:]


def _compress_summary(
    previous_summary_map: Any,
    new_lines: list[str],
) -> dict[str, list[str]]:
    """Fold new ``"user: ..."`` / ``"assistant: ..."`` lines into the summary map.

    Lines are routed by prefix, compacted with ``_compact_summary_item`` and
    merged into the matching list; each list is then capped at
    ``SUMMARY_USER_LIMIT`` / ``SUMMARY_ASSISTANT_LIMIT`` entries. Lines with
    neither prefix are ignored.
    """
    summary_map = _parse_summary_map(previous_summary_map)
    sections = (
        ("user_topics", "user:", SUMMARY_USER_LIMIT),
        ("assistant_points", "assistant:", SUMMARY_ASSISTANT_LIMIT),
    )
    for key, prefix, limit in sections:
        fresh_items = [
            _compact_summary_item(line.removeprefix(prefix).strip())
            for line in new_lines
            if line.startswith(prefix)
        ]
        summary_map[key] = _merge_unique_tail(summary_map[key], fresh_items, limit)
    return summary_map


def _render_summary(summary_map: dict[str, list[str]]) -> str:
    """Render the summary map as text, one ``"<key>: a ; b"`` line per non-empty list.

    The result is truncated to ``SUMMARY_MAX_CHARS`` characters, with a
    trailing ``"..."`` when it had to be cut.
    """
    chunks = [
        f"{key}: " + " ; ".join(summary_map[key])
        for key in ("user_topics", "assistant_points")
        if summary_map.get(key)
    ]
    rendered = "\n".join(chunks).strip()
    if len(rendered) <= SUMMARY_MAX_CHARS:
        return rendered
    return rendered[: SUMMARY_MAX_CHARS - 3].rstrip() + "..."
def _trim_session_history(
    services: ChatServices,
    *,
    user_id: str,
    session_id: str,
) -> None:
    """Cap a stored session's event list and fold the overflow into a summary.

    Does nothing when the session is not found or holds at most
    ``MAX_TURNS_IN_CONTEXT * 2`` events. Otherwise the oldest events are
    removed from the stored session, and their text is merged into
    ``state["conversation_summary_map"]`` and rendered into
    ``state["conversation_summary"]``.
    """
    # NOTE(review): this reads the service's ``sessions`` mapping directly
    # (app name -> user id -> session id) and mutates the stored object in
    # place; presumably ``get_session`` would hand back a copy -- confirm.
    session_store = services.session_service.sessions
    storage_session = session_store.get(APP_NAME, {}).get(user_id, {}).get(session_id)
    if storage_session is None:
        return

    max_events = MAX_TURNS_IN_CONTEXT * 2
    if len(storage_session.events) <= max_events:
        return

    overflow = storage_session.events[:-max_events]
    storage_session.events = storage_session.events[-max_events:]

    previous_summary_map = storage_session.state.get("conversation_summary_map", {})
    summary_map = _compress_summary(previous_summary_map, _event_texts(overflow))
    storage_session.state["conversation_summary_map"] = summary_map
    storage_session.state["conversation_summary"] = _render_summary(summary_map)


def create_chat_services() -> ChatServices:
    """Warm up both retrievers and build the runner with a fresh session service."""
    PERSONA_RETRIEVER.warmup()
    FACT_RETRIEVER.warmup()
    session_service = InMemorySessionService()
    runner = Runner(
        agent=root_agent,
        app_name=APP_NAME,
        session_service=session_service,
    )
    return ChatServices(runner=runner, session_service=session_service)


def _extract_text(event: Any, *, strip: bool = True) -> str:
    """Concatenate the text of every text-bearing part of *event*.

    Args:
        event: Object that may expose ``content.parts``; each part may expose
            a ``text`` attribute.
        strip: When true (the default), surrounding whitespace is removed
            from the joined text. Pass ``False`` to keep it, e.g. when the
            text is one chunk of a longer stream.

    Returns:
        The joined text, or ``""`` when the event has no content, no parts,
        or no non-empty text.
    """
    content = getattr(event, "content", None)
    parts = getattr(content, "parts", None) if content else None
    if not parts:
        return ""
    joined = "".join(getattr(part, "text", None) or "" for part in parts)
    return joined.strip() if strip else joined


async def stream_chat(
    user_message: str,
    services: ChatServices,
    session_id: str | None = None,
    user_id: str = "local-user",
) -> AsyncIterator[tuple[str, str]]:
    """Run one chat turn and yield the reply as it grows.

    The session is created first if it does not exist; a random UUID is used
    when *session_id* is empty. Events authored by ``"user"`` are ignored.
    Text from events flagged ``partial`` is accumulated and the cumulative
    text is yielded each time it visibly changes. After the run, the text of
    the last final-response event is yielded once more if it differs from
    what was streamed. Session history is trimmed when the generator
    finishes, including when it is closed early or the run raises.

    Yields:
        ``(text_so_far, session_id)`` tuples.
    """
    active_session_id = session_id or str(uuid.uuid4())
    session_service = services.session_service

    existing_session = await session_service.get_session(
        app_name=APP_NAME,
        user_id=user_id,
        session_id=active_session_id,
    )
    if existing_session is None:
        await session_service.create_session(
            app_name=APP_NAME,
            user_id=user_id,
            session_id=active_session_id,
        )

    raw_stream = ""  # partial chunks joined verbatim, whitespace included
    streamed_text = ""  # last text handed to the consumer
    final_text = ""
    run_config = RunConfig(streaming_mode=StreamingMode.SSE)

    try:
        async for event in services.runner.run_async(
            user_id=user_id,
            session_id=active_session_id,
            new_message=types.UserContent(parts=[types.Part(text=user_message)]),
            run_config=run_config,
        ):
            if getattr(event, "author", None) == "user":
                continue

            if getattr(event, "partial", None) is True:
                # Keep each chunk's own whitespace: stripping chunks before
                # joining glues words together at chunk boundaries. Only the
                # cumulative text is stripped before it is yielded.
                raw_stream += _extract_text(event, strip=False)
                visible_text = raw_stream.strip()
                if visible_text != streamed_text:
                    streamed_text = visible_text
                    yield streamed_text, active_session_id
                continue

            text = _extract_text(event)
            if not text:
                continue
            if getattr(event, "is_final_response", None) and event.is_final_response():
                final_text = text

        if final_text and final_text != streamed_text:
            streamed_text = final_text
            yield streamed_text, active_session_id
    finally:
        # Trim even when the consumer stops early or the run fails, so the
        # stored history cannot grow without bound.
        _trim_session_history(
            services,
            user_id=user_id,
            session_id=active_session_id,
        )


async def chat_once(
    user_message: str,
    services: ChatServices,
    session_id: str | None = None,
    user_id: str = "local-user",
) -> tuple[str, str]:
    """Run one chat turn to completion.

    Returns:
        ``(reply_text, session_id)``, where ``reply_text`` is the last text
        yielded by ``stream_chat`` (``""`` if it yielded nothing).
    """
    last_text = ""
    active_session_id = session_id or str(uuid.uuid4())
    async for chunk_text, active_session_id in stream_chat(
        user_message=user_message,
        services=services,
        session_id=active_session_id,
        user_id=user_id,
    ):
        last_text = chunk_text
    return last_text, active_session_id