Spaces:
Sleeping
Sleeping
File size: 7,333 Bytes
7245599 c8b5814 6d0bcdd 7245599 c7b2544 7245599 ec10f07 7245599 6d0bcdd c8b5814 7245599 6d0bcdd c8b5814 6d0bcdd c8b5814 6d0bcdd c8b5814 6d0bcdd 7245599 ec10f07 7245599 c8b5814 c7b2544 7245599 c7b2544 7245599 c7b2544 7245599 c7b2544 7245599 c7b2544 7245599 c7b2544 7245599 6d0bcdd c7b2544 7245599 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 | from __future__ import annotations
import uuid
from dataclasses import dataclass
from typing import Any
from typing import AsyncIterator
from typing import Iterable
from .bootstrap import bootstrap_environment
bootstrap_environment()
from google.adk.agents.run_config import RunConfig
from google.adk.agents.run_config import StreamingMode
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types
from .agent import FACT_RETRIEVER
from .agent import PERSONA_RETRIEVER
from .agent import root_agent
# Application name handed to the ADK Runner and used as the top-level key when
# looking up stored sessions.
APP_NAME = "megumin_rag_app"
# Number of turns kept verbatim; the trim step retains twice this many events
# (presumably one user event plus one reply per turn -- confirm).
MAX_TURNS_IN_CONTEXT = 6
# Upper bound on the length of the rendered conversation summary string.
SUMMARY_MAX_CHARS = 800
# Maximum number of user topics retained in the summary map.
SUMMARY_USER_LIMIT = 3
# Maximum number of assistant points retained in the summary map.
SUMMARY_ASSISTANT_LIMIT = 2
# Default character budget for a single compacted summary item.
SUMMARY_ITEM_CHARS = 42
@dataclass
class ChatServices:
    """Pair of objects needed to run a chat turn.

    Holds the agent runner together with the session service the caller
    uses to look up, create and trim sessions.
    """

    # Runner whose ``run_async`` drives the agent for each message.
    runner: Runner
    # Session store queried for existing sessions and trimmed after each turn.
    session_service: InMemorySessionService
def _event_texts(events: Iterable) -> list[str]:
lines: list[str] = []
for event in events:
if not getattr(event, "content", None) or not event.content.parts:
continue
text_parts = [
getattr(part, "text", None)
for part in event.content.parts
if getattr(part, "text", None)
]
if not text_parts:
continue
author = "user" if event.author == "user" else "assistant"
lines.append(f"{author}: {' '.join(text_parts).strip()}")
return lines
def _compact_summary_item(text: str, limit: int = SUMMARY_ITEM_CHARS) -> str:
    """Collapse whitespace in ``text`` and cap the result at ``limit`` chars.

    Runs of whitespace become single spaces and leading/trailing whitespace
    is removed. A falsy ``text`` (``None``, ``""``) yields ``""``. When the
    collapsed text exceeds ``limit`` it is cut and suffixed with ``"..."`` so
    the total length stays within ``limit``.

    Args:
        text: Raw text to compact.
        limit: Maximum length of the returned string.

    Returns:
        The compacted, possibly truncated, string.
    """
    # split()/join already drops leading and trailing whitespace.
    compact = " ".join(str(text or "").split())
    if len(compact) <= limit:
        return compact
    if limit < 3:
        # No room for the ellipsis. Without this guard ``limit - 3`` is a
        # negative slice bound that counts from the end of the string, so the
        # result could be much longer than ``limit``.
        return compact[: max(limit, 0)]
    return compact[: limit - 3].rstrip() + "..."
def _parse_summary_map(value: Any) -> dict[str, list[str]]:
if not isinstance(value, dict):
return {
"user_topics": [],
"assistant_points": [],
}
return {
"user_topics": [
str(item) for item in value.get("user_topics", []) if str(item).strip()
],
"assistant_points": [
str(item)
for item in value.get("assistant_points", [])
if str(item).strip()
],
}
def _merge_unique_tail(previous: list[str], additions: list[str], limit: int) -> list[str]:
merged: list[str] = []
for item in [*previous, *additions]:
if not item or item in merged:
continue
merged.append(item)
return merged[-limit:]
def _compress_summary(
    previous_summary_map: Any,
    new_lines: list[str],
) -> dict[str, list[str]]:
    """Fold newly dropped transcript lines into the running summary map.

    Lines starting with ``"user:"`` feed ``user_topics`` and lines starting
    with ``"assistant:"`` feed ``assistant_points``; other lines are ignored.
    Each line has its prefix removed and is compacted before being merged
    with the previous entries, capped at ``SUMMARY_USER_LIMIT`` and
    ``SUMMARY_ASSISTANT_LIMIT`` items respectively.
    """
    result = _parse_summary_map(previous_summary_map)

    user_items: list[str] = []
    assistant_items: list[str] = []
    for line in new_lines:
        if line.startswith("user:"):
            user_items.append(
                _compact_summary_item(line.removeprefix("user:").strip())
            )
        elif line.startswith("assistant:"):
            assistant_items.append(
                _compact_summary_item(line.removeprefix("assistant:").strip())
            )

    result["user_topics"] = _merge_unique_tail(
        result["user_topics"], user_items, SUMMARY_USER_LIMIT
    )
    result["assistant_points"] = _merge_unique_tail(
        result["assistant_points"], assistant_items, SUMMARY_ASSISTANT_LIMIT
    )
    return result
def _render_summary(summary_map: dict[str, list[str]]) -> str:
    """Format the summary map as text of at most ``SUMMARY_MAX_CHARS`` chars.

    Each non-empty section becomes one line of the form ``"<key>: a ; b"``,
    with ``user_topics`` before ``assistant_points``. Over-long output is cut
    and terminated with ``"..."``.
    """
    sections = [
        f"{key}: " + " ; ".join(summary_map[key])
        for key in ("user_topics", "assistant_points")
        if summary_map.get(key)
    ]
    text = "\n".join(sections).strip()
    if len(text) > SUMMARY_MAX_CHARS:
        text = text[: SUMMARY_MAX_CHARS - 3].rstrip() + "..."
    return text
def _trim_session_history(
    services: ChatServices,
    *,
    user_id: str,
    session_id: str,
) -> None:
    """Cap a stored session's events and summarise the ones removed.

    Looks the session up in ``services.session_service.sessions`` under
    ``APP_NAME`` / ``user_id`` / ``session_id``. Nothing happens when the
    session is missing or holds at most ``MAX_TURNS_IN_CONTEXT * 2`` events.
    Otherwise only the newest events are kept, and the text of the removed
    ones is merged into ``state["conversation_summary_map"]``, with its
    rendered form written to ``state["conversation_summary"]``.
    """
    per_app = services.session_service.sessions.get(APP_NAME, {})
    stored = per_app.get(user_id, {}).get(session_id)
    if stored is None:
        return

    keep = MAX_TURNS_IN_CONTEXT * 2
    events = stored.events
    if len(events) <= keep:
        return

    # Both slices are taken from the untrimmed list before it is replaced.
    dropped, stored.events = events[:-keep], events[-keep:]

    state = stored.state
    summary_map = _compress_summary(
        state.get("conversation_summary_map", {}),
        _event_texts(dropped),
    )
    state["conversation_summary_map"] = summary_map
    state["conversation_summary"] = _render_summary(summary_map)
def create_chat_services() -> ChatServices:
    """Warm up both retrievers and build the runner with a fresh session store.

    Returns:
        A ``ChatServices`` holding a ``Runner`` for ``root_agent`` under
        ``APP_NAME`` and the ``InMemorySessionService`` that runner uses.
    """
    # Persona first, then facts, before any runner exists.
    for retriever in (PERSONA_RETRIEVER, FACT_RETRIEVER):
        retriever.warmup()

    store = InMemorySessionService()
    return ChatServices(
        runner=Runner(
            agent=root_agent,
            app_name=APP_NAME,
            session_service=store,
        ),
        session_service=store,
    )
def _extract_text(event: Any) -> str:
if not getattr(event, "content", None) or not getattr(event.content, "parts", None):
return ""
texts = [
getattr(part, "text", "")
for part in event.content.parts
if getattr(part, "text", "")
]
return "".join(texts).strip()
async def stream_chat(
    user_message: str,
    services: ChatServices,
    session_id: str | None = None,
    user_id: str = "local-user",
) -> AsyncIterator[tuple[str, str]]:
    """Run one chat turn and yield the reply text as it grows.

    A new session id is generated when ``session_id`` is falsy, and the
    session is created if the session service does not already have it.

    Yields:
        ``(text_so_far, session_id)`` tuples: one per partial event carrying
        text, plus a final tuple when the final response text differs from
        what was streamed.

    After the run finishes, the stored session history is trimmed.
    """
    active_session_id = session_id or str(uuid.uuid4())

    session_key = {
        "app_name": APP_NAME,
        "user_id": user_id,
        "session_id": active_session_id,
    }
    if await services.session_service.get_session(**session_key) is None:
        await services.session_service.create_session(**session_key)

    accumulated = ""
    final_reply = ""
    event_stream = services.runner.run_async(
        user_id=user_id,
        session_id=active_session_id,
        new_message=types.UserContent(parts=[types.Part(text=user_message)]),
        run_config=RunConfig(streaming_mode=StreamingMode.SSE),
    )
    async for event in event_stream:
        # Events authored by the user carry no reply text.
        is_user_event = getattr(event, "author", None) == "user"
        chunk = "" if is_user_event else _extract_text(event)
        if not chunk:
            continue

        if getattr(event, "partial", None) is True:
            accumulated += chunk
            yield accumulated, active_session_id
        elif getattr(event, "is_final_response", None) and event.is_final_response():
            final_reply = chunk

    # Emit the final text only when it adds something beyond the partials.
    if final_reply and final_reply != accumulated:
        accumulated = final_reply
        yield accumulated, active_session_id

    _trim_session_history(
        services,
        user_id=user_id,
        session_id=active_session_id,
    )
async def chat_once(
    user_message: str,
    services: ChatServices,
    session_id: str | None = None,
    user_id: str = "local-user",
) -> tuple[str, str]:
    """Run one chat turn to completion and return only the last result.

    Returns:
        ``(reply_text, session_id)`` taken from the last item yielded by
        ``stream_chat``. ``reply_text`` is ``""`` when nothing was yielded,
        and ``session_id`` is then the given or freshly generated id.
    """
    reply = ""
    active_session_id = session_id or str(uuid.uuid4())
    stream = stream_chat(
        user_message=user_message,
        services=services,
        session_id=active_session_id,
        user_id=user_id,
    )
    # Drain the stream; the loop targets retain the last yielded pair.
    async for reply, active_session_id in stream:
        pass
    return reply, active_session_id
|