Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import uuid | |
| from dataclasses import dataclass | |
| from typing import Any | |
| from typing import AsyncIterator | |
| from typing import Iterable | |
| from .bootstrap import bootstrap_environment | |
| bootstrap_environment() | |
| from google.adk.agents.run_config import RunConfig | |
| from google.adk.agents.run_config import StreamingMode | |
| from google.adk.runners import Runner | |
| from google.adk.sessions import InMemorySessionService | |
| from google.genai import types | |
| from .agent import FACT_RETRIEVER | |
| from .agent import PERSONA_RETRIEVER | |
| from .agent import root_agent | |
# ADK application name shared by the Runner and every session lookup.
APP_NAME: str = "megumin_rag_app"

# Recent conversation turns kept verbatim in a session; older history is
# folded into the rolling summary by _trim_session_history.
MAX_TURNS_IN_CONTEXT: int = 6

# Rolling-summary budgets: total rendered length, how many user topics and
# assistant points are retained, and the length cap applied to each item.
SUMMARY_MAX_CHARS: int = 800
SUMMARY_USER_LIMIT: int = 3
SUMMARY_ASSISTANT_LIMIT: int = 2
SUMMARY_ITEM_CHARS: int = 42
@dataclass
class ChatServices:
    """Bundle of the ADK objects a chat front end needs.

    Declared as a dataclass so it can be built with keyword arguments, as
    ``create_chat_services`` does; a plain class with only annotations has
    no such ``__init__`` and would raise ``TypeError`` there.
    """

    # Executes agent turns; built in create_chat_services around the same
    # session service stored below.
    runner: Runner
    # Session store; stream_chat creates sessions here and
    # _trim_session_history rewrites their stored history.
    session_service: InMemorySessionService
| def _event_texts(events: Iterable) -> list[str]: | |
| lines: list[str] = [] | |
| for event in events: | |
| if not getattr(event, "content", None) or not event.content.parts: | |
| continue | |
| text_parts = [ | |
| getattr(part, "text", None) | |
| for part in event.content.parts | |
| if getattr(part, "text", None) | |
| ] | |
| if not text_parts: | |
| continue | |
| author = "user" if event.author == "user" else "assistant" | |
| lines.append(f"{author}: {' '.join(text_parts).strip()}") | |
| return lines | |
def _compact_summary_item(text: str, limit: int = SUMMARY_ITEM_CHARS) -> str:
    """Collapse whitespace in ``text`` and cap the result at ``limit`` chars.

    Falsy input (``None``, ``""``) becomes ``""``; other values go through
    ``str``. Runs of whitespace collapse to single spaces. Text longer than
    ``limit`` is cut and suffixed with ``"..."``, so the result never
    exceeds ``limit`` characters.

    When ``limit`` is below 3 there is no room for the ellipsis, so the
    text is hard-truncated instead (``limit <= 0`` yields ``""``).
    Previously ``compact[: limit - 3]`` became a negative slice there, and
    the output ran far past ``limit``.
    """
    # split()/join already removes leading and trailing whitespace.
    compact = " ".join(str(text or "").split())
    if len(compact) <= limit:
        return compact
    if limit < 3:
        return compact[: max(limit, 0)]
    return compact[: limit - 3].rstrip() + "..."
| def _parse_summary_map(value: Any) -> dict[str, list[str]]: | |
| if not isinstance(value, dict): | |
| return { | |
| "user_topics": [], | |
| "assistant_points": [], | |
| } | |
| return { | |
| "user_topics": [ | |
| str(item) for item in value.get("user_topics", []) if str(item).strip() | |
| ], | |
| "assistant_points": [ | |
| str(item) | |
| for item in value.get("assistant_points", []) | |
| if str(item).strip() | |
| ], | |
| } | |
| def _merge_unique_tail(previous: list[str], additions: list[str], limit: int) -> list[str]: | |
| merged: list[str] = [] | |
| for item in [*previous, *additions]: | |
| if not item or item in merged: | |
| continue | |
| merged.append(item) | |
| return merged[-limit:] | |
def _compress_summary(
    previous_summary_map: Any,
    new_lines: list[str],
) -> dict[str, list[str]]:
    """Fold new transcript lines into the rolling summary map.

    Lines starting with ``"user:"`` become user topics. Lines starting with
    ``"assistant:"`` become assistant points. Each has its prefix removed,
    is stripped, and is shortened by ``_compact_summary_item``. Any other
    line is ignored.

    The new items are merged into the sections parsed from
    ``previous_summary_map`` by ``_merge_unique_tail``, capped at
    ``SUMMARY_USER_LIMIT`` and ``SUMMARY_ASSISTANT_LIMIT`` respectively.
    """
    # Prefix -> compacted bodies of the lines carrying that prefix, in order.
    buckets: dict[str, list[str]] = {"user:": [], "assistant:": []}
    for raw_line in new_lines:
        for prefix, bucket in buckets.items():
            if raw_line.startswith(prefix):
                bucket.append(_compact_summary_item(raw_line[len(prefix):].strip()))
                break

    merged = _parse_summary_map(previous_summary_map)
    merged["user_topics"] = _merge_unique_tail(
        merged["user_topics"], buckets["user:"], SUMMARY_USER_LIMIT
    )
    merged["assistant_points"] = _merge_unique_tail(
        merged["assistant_points"], buckets["assistant:"], SUMMARY_ASSISTANT_LIMIT
    )
    return merged
def _render_summary(summary_map: dict[str, list[str]]) -> str:
    """Render a summary map as text capped at ``SUMMARY_MAX_CHARS``.

    Each non-empty section becomes one ``"<key>: a ; b ; c"`` line, with
    ``user_topics`` first and ``assistant_points`` second. The lines are
    newline-joined and stripped. Output longer than the cap is cut and
    suffixed with ``"..."``.
    """
    sections = [
        f"{key}: " + " ; ".join(summary_map[key])
        for key in ("user_topics", "assistant_points")
        if summary_map.get(key)
    ]
    text = "\n".join(sections).strip()
    if len(text) > SUMMARY_MAX_CHARS:
        text = text[: SUMMARY_MAX_CHARS - 3].rstrip() + "..."
    return text
def _trim_session_history(
    services: ChatServices,
    *,
    user_id: str,
    session_id: str,
) -> None:
    """Keep the last ``MAX_TURNS_IN_CONTEXT`` turns and summarize the rest.

    A turn starts at a user-authored event and runs up to, but not
    including, the next one. Everything before the oldest retained turn is
    removed from the stored session. Its text is merged into session state
    under two keys:

    - ``conversation_summary_map``: the structured form, read back on the
      next trim.
    - ``conversation_summary``: the rendered string.

    Cutting on a user event means the retained history never opens in the
    middle of a turn. The old fixed ``2 * MAX_TURNS_IN_CONTEXT`` event
    window could do exactly that whenever a turn had more than a
    message + reply pair (e.g. intermediate or tool events, if the agent
    emits them). For plain message/reply turns the result is unchanged.

    Does nothing when the session is not in the store or holds no more
    than ``MAX_TURNS_IN_CONTEXT`` turns.

    NOTE(review): this reads ``InMemorySessionService.sessions`` (indexed
    app -> user -> session below) and mutates the stored session in place.
    That attribute is presumably specific to the in-memory service, so
    swapping session services will need a different approach. Confirm.
    """
    session_store = services.session_service.sessions
    storage_session = session_store.get(APP_NAME, {}).get(user_id, {}).get(session_id)
    if storage_session is None:
        return

    # Always keep at least the most recent turn; a non-positive limit would
    # otherwise index turn_starts[-0] == turn_starts[0] and keep everything.
    keep_turns = max(MAX_TURNS_IN_CONTEXT, 1)
    events = storage_session.events
    turn_starts = [
        index
        for index, event in enumerate(events)
        if getattr(event, "author", None) == "user"
    ]
    if len(turn_starts) <= keep_turns:
        return

    cut = turn_starts[-keep_turns]
    overflow = events[:cut]
    storage_session.events = events[cut:]

    previous_summary_map = storage_session.state.get("conversation_summary_map", {})
    summary_map = _compress_summary(previous_summary_map, _event_texts(overflow))
    storage_session.state["conversation_summary_map"] = summary_map
    storage_session.state["conversation_summary"] = _render_summary(summary_map)
def create_chat_services() -> ChatServices:
    """Warm up both retrievers and assemble the runner for ``root_agent``.

    Creates one fresh ``InMemorySessionService`` and hands it both to the
    ``Runner`` and to the returned ``ChatServices``. Sessions created
    through ``ChatServices.session_service`` are therefore the ones the
    runner uses.
    """
    for retriever in (PERSONA_RETRIEVER, FACT_RETRIEVER):
        retriever.warmup()
    sessions = InMemorySessionService()
    return ChatServices(
        runner=Runner(agent=root_agent, app_name=APP_NAME, session_service=sessions),
        session_service=sessions,
    )
| def _extract_text(event: Any) -> str: | |
| if not getattr(event, "content", None) or not getattr(event.content, "parts", None): | |
| return "" | |
| texts = [ | |
| getattr(part, "text", "") | |
| for part in event.content.parts | |
| if getattr(part, "text", "") | |
| ] | |
| return "".join(texts).strip() | |
| async def stream_chat( | |
| user_message: str, | |
| services: ChatServices, | |
| session_id: str | None = None, | |
| user_id: str = "local-user", | |
| ) -> AsyncIterator[tuple[str, str]]: | |
| active_session_id = session_id or str(uuid.uuid4()) | |
| existing_session = await services.session_service.get_session( | |
| app_name=APP_NAME, | |
| user_id=user_id, | |
| session_id=active_session_id, | |
| ) | |
| if existing_session is None: | |
| await services.session_service.create_session( | |
| app_name=APP_NAME, | |
| user_id=user_id, | |
| session_id=active_session_id, | |
| ) | |
| streamed_text = "" | |
| final_text = "" | |
| run_config = RunConfig(streaming_mode=StreamingMode.SSE) | |
| async for event in services.runner.run_async( | |
| user_id=user_id, | |
| session_id=active_session_id, | |
| new_message=types.UserContent(parts=[types.Part(text=user_message)]), | |
| run_config=run_config, | |
| ): | |
| if getattr(event, "author", None) == "user": | |
| continue | |
| text = _extract_text(event) | |
| if not text: | |
| continue | |
| if getattr(event, "partial", None) is True: | |
| streamed_text += text | |
| yield streamed_text, active_session_id | |
| continue | |
| if getattr(event, "is_final_response", None) and event.is_final_response(): | |
| final_text = text | |
| if final_text and final_text != streamed_text: | |
| streamed_text = final_text | |
| yield streamed_text, active_session_id | |
| _trim_session_history( | |
| services, | |
| user_id=user_id, | |
| session_id=active_session_id, | |
| ) | |
async def chat_once(
    user_message: str,
    services: ChatServices,
    session_id: str | None = None,
    user_id: str = "local-user",
) -> tuple[str, str]:
    """Run one chat turn to completion and return only its last text.

    Drains ``stream_chat`` and returns ``(last_yielded_text, session_id)``.
    If the stream yields nothing, the text is ``""`` and the session id is
    ``session_id``, or the UUID4 string generated here when it was falsy.
    The same id is the one passed to ``stream_chat``.
    """
    resolved_session = session_id or str(uuid.uuid4())
    reply = ""
    stream = stream_chat(
        user_message=user_message,
        services=services,
        session_id=resolved_session,
        user_id=user_id,
    )
    # Only the last (text, session id) pair matters; the loop keeps rebinding both.
    async for reply, resolved_session in stream:
        pass
    return reply, resolved_session