Junhoee commited on
Commit
6d0bcdd
·
verified ·
1 Parent(s): 1c7c36f

Update megumin_agent/chat.py

Browse files
Files changed (1) hide show
  1. megumin_agent/chat.py +64 -0
megumin_agent/chat.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import uuid
4
  from dataclasses import dataclass
 
5
 
6
  from .bootstrap import bootstrap_environment
7
 
@@ -17,6 +18,8 @@ from .retrieval import JsonQaRetriever
17
 
18
 
19
  APP_NAME = "megumin_rag_app"
 
 
20
 
21
 
22
  @dataclass
@@ -25,6 +28,61 @@ class ChatServices:
25
  session_service: InMemorySessionService
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def create_chat_services() -> ChatServices:
29
  JsonQaRetriever(DATASET_DIR).warmup()
30
  session_service = InMemorySessionService()
@@ -68,4 +126,10 @@ async def chat_once(
68
  if text and event.author != "user":
69
  last_text = text
70
 
 
 
 
 
 
 
71
  return last_text, active_session_id
 
2
 
3
  import uuid
4
  from dataclasses import dataclass
5
+ from typing import Iterable
6
 
7
  from .bootstrap import bootstrap_environment
8
 
 
18
 
19
 
20
  APP_NAME = "megumin_rag_app"
21
+ MAX_TURNS_IN_CONTEXT = 6
22
+ SUMMARY_MAX_CHARS = 800
23
 
24
 
25
  @dataclass
 
28
  session_service: InMemorySessionService
29
 
30
 
31
+ def _event_texts(events: Iterable) -> list[str]:
32
+ lines: list[str] = []
33
+ for event in events:
34
+ if not getattr(event, "content", None) or not event.content.parts:
35
+ continue
36
+ text_parts = [
37
+ getattr(part, "text", None)
38
+ for part in event.content.parts
39
+ if getattr(part, "text", None)
40
+ ]
41
+ if not text_parts:
42
+ continue
43
+ author = "user" if event.author == "user" else "assistant"
44
+ lines.append(f"{author}: {' '.join(text_parts).strip()}")
45
+ return lines
46
+
47
+
48
+ def _compress_summary(previous_summary: str, new_lines: list[str]) -> str:
49
+ pieces = [previous_summary.strip()] if previous_summary.strip() else []
50
+ if new_lines:
51
+ pieces.append(" / ".join(new_lines))
52
+ summary = " | ".join(piece for piece in pieces if piece).strip()
53
+ if len(summary) <= SUMMARY_MAX_CHARS:
54
+ return summary
55
+ return "..." + summary[-(SUMMARY_MAX_CHARS - 3) :]
56
+
57
+
58
+ def _trim_session_history(
59
+ services: ChatServices,
60
+ *,
61
+ user_id: str,
62
+ session_id: str,
63
+ ) -> None:
64
+ session_store = services.session_service.sessions
65
+ storage_session = (
66
+ session_store.get(APP_NAME, {})
67
+ .get(user_id, {})
68
+ .get(session_id)
69
+ )
70
+ if storage_session is None:
71
+ return
72
+
73
+ max_events = MAX_TURNS_IN_CONTEXT * 2
74
+ if len(storage_session.events) <= max_events:
75
+ return
76
+
77
+ overflow = storage_session.events[:-max_events]
78
+ storage_session.events = storage_session.events[-max_events:]
79
+ previous_summary = str(storage_session.state.get("conversation_summary", ""))
80
+ storage_session.state["conversation_summary"] = _compress_summary(
81
+ previous_summary,
82
+ _event_texts(overflow),
83
+ )
84
+
85
+
86
  def create_chat_services() -> ChatServices:
87
  JsonQaRetriever(DATASET_DIR).warmup()
88
  session_service = InMemorySessionService()
 
126
  if text and event.author != "user":
127
  last_text = text
128
 
129
+ _trim_session_history(
130
+ services,
131
+ user_id=user_id,
132
+ session_id=active_session_id,
133
+ )
134
+
135
  return last_text, active_session_id