anhkhoiphan committed on
Commit
c812bd3
·
1 Parent(s): 55834a6

Bổ sung tính năng scale buffer memory theo số thành viên trong nhóm

Browse files
Files changed (3) hide show
  1. api_requirements.txt +1 -0
  2. config.py +4 -0
  3. conversation_memory.py +79 -31
api_requirements.txt CHANGED
@@ -15,3 +15,4 @@ pdfplumber>=0.11.0
15
  openai>=1.50.0
16
  fastembed>=0.4.0
17
  qdrant-client>=1.10.0
 
 
15
  openai>=1.50.0
16
  fastembed>=0.4.0
17
  qdrant-client>=1.10.0
18
+ supabase>=2.0.0
config.py CHANGED
@@ -29,6 +29,10 @@ if REDIS_URL:
29
  QDRANT_URL = os.getenv("QDRANT_URL", "")
30
  QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
31
 
 
 
 
 
32
  # Logging
33
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
34
 
 
29
  QDRANT_URL = os.getenv("QDRANT_URL", "")
30
  QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
31
 
32
+ # Supabase
33
+ SUPABASE_URL = os.getenv("SUPABASE_URL", "")
34
+ SUPABASE_SERVICE_ROLE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY", "")
35
+
36
  # Logging
37
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
38
 
conversation_memory.py CHANGED
@@ -1,12 +1,14 @@
1
  """
2
  ConversationSummaryBufferMemory — Qdrant-backed.
3
 
4
- Buffer: tối đa MAX_BUFFER tin nhắn.
5
- Khi vượt quá: tóm tắt SUMMARIZE_COUNT tin cũ nhất, giữ lại KEEP_RECENT tin mới nhất.
6
- Lưu cả summary buffer trên Qdrant (payload-only, dummy vector).
 
 
 
7
  """
8
 
9
- import json
10
  import logging
11
  import uuid
12
  from typing import Optional
@@ -15,27 +17,27 @@ from langchain_core.messages import HumanMessage, SystemMessage
15
  from qdrant_client import QdrantClient
16
  from qdrant_client.models import Distance, PointStruct, VectorParams
17
 
18
- from src.config import QDRANT_API_KEY, QDRANT_URL
19
  from src.llm import llm
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
- MAX_BUFFER = 100 # ngưỡng trigger summarization
24
- SUMMARIZE_COUNT = 60 # số tin nhắn cũ nhất sẽ được tóm tắt
25
- KEEP_RECENT = 40 # số tin nhắn giữ lại trong buffer sau summarization
26
-
27
  _COLLECTION = "conversation_memory"
28
  _DUMMY_VECTOR = [0.0]
29
 
30
- _client: Optional[QdrantClient] = None
 
31
 
32
 
33
- def _get_client() -> QdrantClient:
34
- global _client
35
- if _client is None:
36
- _client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
37
- _ensure_collection(_client)
38
- return _client
 
 
39
 
40
 
41
  def _ensure_collection(client: QdrantClient) -> None:
@@ -52,6 +54,51 @@ def _point_id(conversation_id: str) -> str:
52
  return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"conv:{conversation_id}"))
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  # ── Load / Save ───────────────────────────────────────────────────────────────
56
 
57
  def load(conversation_id: str) -> tuple[str, list[dict]]:
@@ -59,7 +106,7 @@ def load(conversation_id: str) -> tuple[str, list[dict]]:
59
  if not QDRANT_URL:
60
  return "", []
61
  try:
62
- results = _get_client().retrieve(
63
  collection_name=_COLLECTION,
64
  ids=[_point_id(conversation_id)],
65
  with_payload=True,
@@ -79,7 +126,7 @@ def save(conversation_id: str, summary: str, buffer: list[dict]) -> None:
79
  if not QDRANT_URL:
80
  return
81
  try:
82
- _get_client().upsert(
83
  collection_name=_COLLECTION,
84
  points=[PointStruct(
85
  id=_point_id(conversation_id),
@@ -134,19 +181,19 @@ def _summarize(existing_summary: str, messages: list[dict]) -> str:
134
  def add_turn(conversation_id: str, user_msg: str, ai_msg: str) -> None:
135
  """Thêm 1 lượt user+assistant vào buffer, trigger summarize nếu cần."""
136
  summary, buffer = load(conversation_id)
 
137
 
138
  buffer.append({"role": "user", "content": user_msg})
139
  buffer.append({"role": "assistant", "content": ai_msg})
140
 
141
- if len(buffer) > MAX_BUFFER:
142
- to_summarize = buffer[:SUMMARIZE_COUNT]
143
- buffer = buffer[SUMMARIZE_COUNT:] # giữ lại phần còn lại
144
- # Đảm bảo buffer không vượt KEEP_RECENT sau khi cắt
145
- if len(buffer) > KEEP_RECENT:
146
- buffer = buffer[-KEEP_RECENT:]
147
  logger.info(
148
  "[Memory] Buffer vượt %d, tóm tắt %d tin → giữ %d tin.",
149
- MAX_BUFFER, len(to_summarize), len(buffer),
150
  )
151
  summary = _summarize(summary, to_summarize)
152
 
@@ -157,7 +204,7 @@ def seed_room(conversation_id: str, messages: list[dict]) -> None:
157
  """
158
  Seed Qdrant buffer từ danh sách tin nhắn Redis thô.
159
  Mỗi message được chuyển thành role='user', content='[ts] name (id): content'.
160
- Nếu vượt MAX_BUFFER thì tự động summarize trước khi lưu.
161
  """
162
  _NAME_FIELDS = ["sender_username", "username", "u_username", "name", "u_name",
163
  "senderName", "displayName", "display_name", "fullName", "sender_id"]
@@ -179,12 +226,13 @@ def seed_room(conversation_id: str, messages: list[dict]) -> None:
179
  label = f"{name} ({sid})" if sid != name else name
180
  buffer.append({"role": "user", "content": f"[{ts}] {label}: {content}"})
181
 
 
182
  summary = ""
183
- while len(buffer) > MAX_BUFFER:
184
- to_summarize = buffer[:SUMMARIZE_COUNT]
185
- buffer = buffer[SUMMARIZE_COUNT:]
186
- if len(buffer) > KEEP_RECENT:
187
- buffer = buffer[-KEEP_RECENT:]
188
  logger.info(
189
  "[Memory] seed_room: tóm tắt %d tin → giữ %d tin còn lại",
190
  len(to_summarize), len(buffer),
 
1
  """
2
  ConversationSummaryBufferMemory — Qdrant-backed.
3
 
4
+ Giới hạn buffer tính theo số thành viên room (n):
5
+ MAX_BUFFER = 10n
6
+ SUMMARIZE_COUNT = 6n (số tin cũ nhất được tóm tắt khi vượt ngưỡng)
7
+ KEEP_RECENT = 4n (số tin giữ lại trong buffer sau khi tóm tắt)
8
+
9
+ Fallback n = 20 nếu không kết nối được Supabase hoặc không phải room.
10
  """
11
 
 
12
  import logging
13
  import uuid
14
  from typing import Optional
 
17
  from qdrant_client import QdrantClient
18
  from qdrant_client.models import Distance, PointStruct, VectorParams
19
 
20
+ from src.config import QDRANT_API_KEY, QDRANT_URL, SUPABASE_SERVICE_ROLE_KEY, SUPABASE_URL
21
  from src.llm import llm
22
 
23
  logger = logging.getLogger(__name__)
24
 
25
+ _DEFAULT_N = 20
 
 
 
26
  _COLLECTION = "conversation_memory"
27
  _DUMMY_VECTOR = [0.0]
28
 
29
+ _qdrant_client: Optional[QdrantClient] = None
30
+ _sb_client = None
31
 
32
 
33
+ # ── Qdrant client ─────────────────────────────────────────────────────────────
34
+
35
def _get_qdrant() -> QdrantClient:
    """Return the shared Qdrant client, creating it on first use.

    The client is stored in the module-level cache only after
    ``_ensure_collection`` has completed without raising. If construction or
    the collection check fails, the cache stays empty and the next call
    retries, instead of handing out a client whose collection was never
    verified.

    Returns:
        The cached ``QdrantClient`` built from ``QDRANT_URL`` and
        ``QDRANT_API_KEY``.

    Raises:
        Whatever ``QdrantClient(...)`` or ``_ensure_collection`` raises; the
        exception is propagated unchanged to the caller.
    """
    global _qdrant_client
    if _qdrant_client is None:
        # Build into a local first so a failure below does not leave a
        # half-initialised client in the global cache.
        client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
        _ensure_collection(client)
        _qdrant_client = client
    return _qdrant_client
41
 
42
 
43
  def _ensure_collection(client: QdrantClient) -> None:
 
54
  return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"conv:{conversation_id}"))
55
 
56
 
57
+ # ── Supabase client ───────────────────────────────────────────────────────────
58
+
59
def _get_sb():
    """Return the shared Supabase client, or ``None`` when unavailable.

    The client is created lazily on first use. ``None`` is returned when
    ``SUPABASE_URL`` or ``SUPABASE_SERVICE_ROLE_KEY`` is empty, or when
    importing/creating the client raises (the error is logged, not raised).
    """
    global _sb_client

    # Already initialised: reuse the cached client.
    if _sb_client is not None:
        return _sb_client

    # Missing configuration: nothing to create.
    if not (SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY):
        return None

    try:
        # Imported here so the module still loads when supabase is absent.
        from supabase import create_client
        _sb_client = create_client(SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY)
    except Exception:
        logger.exception("[Memory] Không khởi tạo được Supabase client.")
    return _sb_client
68
+
69
+
70
+ # ── Dynamic limits ────────────────────────────────────────────────────────────
71
+
72
def _get_member_count(conversation_id: str) -> int:
    """Return the number of members in a room, falling back to ``_DEFAULT_N``.

    The fallback is used when ``conversation_id`` does not start with
    ``"room-"``, when no Supabase client is available, when the reported
    count is missing or not positive, or when the query raises (the error
    is logged, not raised).

    Args:
        conversation_id: Conversation identifier; room conversations carry
            the ``"room-"`` prefix, and the remainder is used as the room id.

    Returns:
        The exact row count reported for the room in ``room_members``, or
        ``_DEFAULT_N``.
    """
    prefix = "room-"
    if not conversation_id.startswith(prefix):
        return _DEFAULT_N

    client = _get_sb()
    if client is None:
        return _DEFAULT_N

    room_id = conversation_id.removeprefix(prefix)
    try:
        query = client.table("room_members").select("user_id", count="exact")
        response = query.eq("room_id", room_id).execute()
        # ``count`` may be None; treat that the same as zero.
        member_count = response.count or 0
        if member_count > 0:
            return member_count
        return _DEFAULT_N
    except Exception:
        logger.exception("[Memory] Lỗi lấy số thành viên room '%s'", room_id)
        return _DEFAULT_N
94
+
95
+
96
def _get_limits(conversation_id: str) -> tuple[int, int, int]:
    """Return ``(max_buffer, summarize_count, keep_recent)`` for a conversation.

    Each limit is a fixed multiple of the member count ``n`` returned by
    ``_get_member_count``: ``10 * n``, ``6 * n`` and ``4 * n`` respectively.
    """
    members = _get_member_count(conversation_id)
    max_buffer = 10 * members
    summarize_count = 6 * members
    keep_recent = 4 * members
    return max_buffer, summarize_count, keep_recent
100
+
101
+
102
  # ── Load / Save ───────────────────────────────────────────────────────────────
103
 
104
  def load(conversation_id: str) -> tuple[str, list[dict]]:
 
106
  if not QDRANT_URL:
107
  return "", []
108
  try:
109
+ results = _get_qdrant().retrieve(
110
  collection_name=_COLLECTION,
111
  ids=[_point_id(conversation_id)],
112
  with_payload=True,
 
126
  if not QDRANT_URL:
127
  return
128
  try:
129
+ _get_qdrant().upsert(
130
  collection_name=_COLLECTION,
131
  points=[PointStruct(
132
  id=_point_id(conversation_id),
 
181
  def add_turn(conversation_id: str, user_msg: str, ai_msg: str) -> None:
182
  """Thêm 1 lượt user+assistant vào buffer, trigger summarize nếu cần."""
183
  summary, buffer = load(conversation_id)
184
+ max_buffer, summarize_count, keep_recent = _get_limits(conversation_id)
185
 
186
  buffer.append({"role": "user", "content": user_msg})
187
  buffer.append({"role": "assistant", "content": ai_msg})
188
 
189
+ if len(buffer) > max_buffer:
190
+ to_summarize = buffer[:summarize_count]
191
+ buffer = buffer[summarize_count:]
192
+ if len(buffer) > keep_recent:
193
+ buffer = buffer[-keep_recent:]
 
194
  logger.info(
195
  "[Memory] Buffer vượt %d, tóm tắt %d tin → giữ %d tin.",
196
+ max_buffer, len(to_summarize), len(buffer),
197
  )
198
  summary = _summarize(summary, to_summarize)
199
 
 
204
  """
205
  Seed Qdrant buffer từ danh sách tin nhắn Redis thô.
206
  Mỗi message được chuyển thành role='user', content='[ts] name (id): content'.
207
+ Nếu vượt max_buffer thì tự động summarize trước khi lưu.
208
  """
209
  _NAME_FIELDS = ["sender_username", "username", "u_username", "name", "u_name",
210
  "senderName", "displayName", "display_name", "fullName", "sender_id"]
 
226
  label = f"{name} ({sid})" if sid != name else name
227
  buffer.append({"role": "user", "content": f"[{ts}] {label}: {content}"})
228
 
229
+ max_buffer, summarize_count, keep_recent = _get_limits(conversation_id)
230
  summary = ""
231
+ while len(buffer) > max_buffer:
232
+ to_summarize = buffer[:summarize_count]
233
+ buffer = buffer[summarize_count:]
234
+ if len(buffer) > keep_recent:
235
+ buffer = buffer[-keep_recent:]
236
  logger.info(
237
  "[Memory] seed_room: tóm tắt %d tin → giữ %d tin còn lại",
238
  len(to_summarize), len(buffer),