anhkhoiphan commited on
Commit
66ba458
·
1 Parent(s): 91fe7ce

Thêm tool vẽ chart

Browse files
Files changed (2) hide show
  1. tools/chart.py +201 -0
  2. tools/chat_tools.py +2 -0
tools/chart.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chart summarization tool — reads messages, counts unique user opinions, formats for charting.
3
+ """
4
+
5
+ import time
6
+ import logging
7
+ from typing import List
8
+
9
+ from pydantic import BaseModel, Field
10
+ from langchain_core.output_parsers import JsonOutputParser
11
+ from langchain_core.prompts import ChatPromptTemplate
12
+
13
+ from .base import register_tool, get_llm
14
+ from .utils import preprocess_messages
15
+
16
+ try:
17
+ from ..redis_client import redis_client
18
+ except (ImportError, ValueError):
19
+ from redis_client import redis_client
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Categories dưới ngưỡng này (% trên tổng) sẽ bị gom vào "others"
24
+ _OTHERS_THRESHOLD = 0.05
25
+
26
+
27
+ # ── Pydantic schema ──────────────────────────────────────────────────────────
28
+
29
+ class ChartItem(BaseModel):
30
+ label: str = Field(description="Tên danh mục / nhãn")
31
+ count: int = Field(description="Số lượng unique users có ý kiến này")
32
+
33
+ class ChartDataResponse(BaseModel):
34
+ items: List[ChartItem] = Field(description="Danh sách danh mục và số lượng unique users, sắp xếp theo count giảm dần")
35
+
36
+
37
+ # ── System prompt ────────────────────────────────────────────────────────────
38
+
39
+ _SYSTEM_PROMPT = """
40
+ Bạn là một nhà phân tích dữ liệu chuyên nghiệp.
41
+ Nhiệm vụ: Đọc tin nhắn, xác định chủ đề từ query, thống kê ý kiến của các unique users.
42
+
43
+ ══ BƯỚC 1 — XÁC ĐỊNH LOẠI DỮ LIỆU ══
44
+
45
+ Dựa vào query, phân loại dữ liệu cần thống kê:
46
+
47
+ PHÂN LOẠI (categorical): nghề nghiệp, môn học yêu thích, ngôn ngữ lập trình,
48
+ sở thích, hệ điều hành, stack công nghệ, v.v.
49
+ → Gom các giá trị tương đồng vào cùng nhãn.
50
+
51
+ SỐ (numerical): tuổi, năm kinh nghiệm, điểm GPA, mức lương, số giờ học/ngày.
52
+ → Binning thành khoảng giá trị thay vì giữ nguyên từng con số.
53
+
54
+ NHỊ PHÂN (binary): có/không, đồng ý/phản đối, nam/nữ.
55
+ → Giữ nguyên 2 nhãn, gom biến thể ("có", "yes", "ok" → "có").
56
+
57
+ ══ BƯỚC 2 — ĐẾM UNIQUE USERS ══
58
+
59
+ - Mỗi user chỉ được đếm 1 lần cho mỗi danh mục.
60
+ - Nếu user đề cập nhiều lần: lấy ý kiến RÕ RÀNG nhất (không phải đầu tiên).
61
+ - Bỏ qua tin nhắn mơ hồ, không liên quan, hoặc chỉ là phản ứng (emoji, "ok", "oke").
62
+ - Nhận diện user qua: tên hiển thị, username, hoặc sender_id — xử lý nhất quán.
63
+
64
+ ══ BƯỚC 3 — QUY TẮC THEO LOẠI ══
65
+
66
+ ▸ PHÂN LOẠI: Gom đồng nghĩa vào một nhãn chuẩn:
67
+ "SE", "software eng", "kỹ sư phần mềm" → "software engineer"
68
+ "ML", "machine learning eng" → "ml engineer"
69
+ "FE", "frontend" → "frontend developer"
70
+ Nhãn: viết thường, ngắn gọn, tiếng Anh nếu là thuật ngữ kỹ thuật.
71
+
72
+ ▸ SỐ (binning):
73
+ - Chọn kích thước bin phù hợp với độ phân tán:
74
+ Tuổi → khoảng 3–5 tuổi (VD: "18-22", "23-27", "28-32")
75
+ Kinh nghiệm → khoảng 1–2 năm (VD: "0-1 năm", "2-3 năm", "4+ năm")
76
+ GPA → khoảng 0.5 (VD: "3.0-3.5", "3.5-4.0")
77
+ - Không tạo quá 8 bin; gom đuôi nếu cần (VD: "35+" thay vì nhiều bin lẻ).
78
+ - Nhãn bin viết dạng "min-max" hoặc "min+" nếu là đuôi hở.
79
+
80
+ ▸ NHỊ PHÂN: Chuẩn hóa về đúng 2 nhãn đối lập.
81
+
82
+ ══ BƯỚC 4 — CHUẨN HÓA OUTPUT ══
83
+
84
+ - Sắp xếp theo count giảm dần.
85
+ - Loại bỏ danh mục có count = 0.
86
+ - Trả về đúng schema JSON yêu cầu.
87
+ """
88
+
89
+
90
+ # ── Tool ─────────────────────────────────────────────────────────────────────
91
+
92
+ @register_tool(
93
+ name="summarize_chart",
94
+ description=(
95
+ "Đọc tin nhắn nhóm, thống kê ý kiến của unique users theo chủ đề từ query, "
96
+ "xuất dữ liệu JSON để vẽ biểu đồ cột (column) hoặc tròn (pie). "
97
+ "Dùng khi người dùng muốn thống kê / vẽ biểu đồ từ dữ liệu trong chat."
98
+ ),
99
+ parameters=[
100
+ {"name": "query", "type": "string", "description": "Chủ đề/yêu cầu thống kê (VD: 'nghề nghiệp thành viên', 'độ tuổi').", "required": True},
101
+ {"name": "chart_type", "type": "string", "description": '"column" để vẽ biểu đồ cột, "pie" để vẽ biểu đồ tròn.', "required": True},
102
+ {"name": "conversation_id", "type": "string", "description": "ID hội thoại (conversation_id trong dmmsg, hoặc room-{id} cho phòng nhóm).", "required": False},
103
+ {"name": "room_id", "type": "string", "description": "ID phòng chat nhóm (không có prefix room-).", "required": False},
104
+ {"name": "dm_id", "type": "string", "description": "ID cuộc hội thoại DM theo Sorted Set.", "required": False},
105
+ {"name": "limit", "type": "integer", "description": "Số tin nhắn tối đa cần đọc (mặc định: 200).", "required": False},
106
+ ],
107
+ )
108
+ def tool_summarize_chart(
109
+ query: str,
110
+ chart_type: str,
111
+ messages: List[dict] = None,
112
+ conversation_id: str = None,
113
+ room_id: str = None,
114
+ dm_id: str = None,
115
+ limit: int = 200,
116
+ ) -> dict:
117
+ start_time = time.time()
118
+
119
+ chart_type = (chart_type or "column").strip().lower()
120
+ if chart_type not in ("column", "pie"):
121
+ chart_type = "column"
122
+
123
+ try:
124
+ # ── 1. Lấy tin nhắn ────────────────────────────────────────────────
125
+ if messages is None:
126
+ if conversation_id:
127
+ messages = redis_client.get_messages_by_conversation_id(conversation_id, limit)
128
+ elif room_id:
129
+ messages = redis_client.get_room_messages(room_id, limit)
130
+ elif dm_id:
131
+ messages = redis_client.get_dm_messages(dm_id, limit)
132
+ else:
133
+ return {"status": "error", "data": {"error": "Cần cung cấp conversation_id, room_id hoặc dm_id."}}
134
+
135
+ if not messages:
136
+ return {"status": "error", "data": {"error": "Không có tin nhắn để phân tích."}}
137
+
138
+ # ── 2. Gọi LLM thống kê ────────────────────────────────────────────
139
+ formatted = preprocess_messages(messages)
140
+ llm = get_llm()
141
+ parser = JsonOutputParser(pydantic_object=ChartDataResponse)
142
+
143
+ prompt = ChatPromptTemplate.from_messages([
144
+ ("system", _SYSTEM_PROMPT),
145
+ ("human", (
146
+ "Query: {query}\n\n"
147
+ "NỘI DUNG TIN NHẮN:\n{messages}\n\n"
148
+ "{format_instructions}"
149
+ )),
150
+ ])
151
+
152
+ chain = prompt | llm | parser
153
+ result = chain.invoke({
154
+ "query": query,
155
+ "messages": formatted,
156
+ "format_instructions": parser.get_format_instructions(),
157
+ })
158
+
159
+ raw_items = result.get("items", [])
160
+ if not raw_items:
161
+ return {"status": "error", "data": {"error": "Không tìm thấy dữ liệu phù hợp với query."}}
162
+
163
+ # ── 3. Format theo loại chart ──────────────────────────────────────
164
+ chart_data = _format_chart(raw_items, chart_type)
165
+
166
+ return {
167
+ "status": "success",
168
+ "chart_type": chart_type,
169
+ "chart_data": chart_data,
170
+ "total_responses": sum(i.get("count", 0) for i in raw_items),
171
+ "metrics": {"processing_time_sec": round(time.time() - start_time, 2)},
172
+ }
173
+
174
+ except Exception as e:
175
+ logger.error(f"Chart tool error: {e}")
176
+ return {"status": "error", "data": {"error": str(e)}}
177
+
178
+
179
+ # ── Helpers ──────────────────────────────────────────────────────────────────
180
+
181
+ def _format_chart(items: list[dict], chart_type: str) -> list[dict]:
182
+ total = sum(i.get("count", 0) for i in items)
183
+ if total == 0:
184
+ return []
185
+
186
+ main = [i for i in items if i.get("count", 0) / total >= _OTHERS_THRESHOLD]
187
+ others = sum(i.get("count", 0) for i in items if i.get("count", 0) / total < _OTHERS_THRESHOLD)
188
+
189
+ if chart_type == "pie":
190
+ result = [
191
+ {"label": i["label"], "percentage": round(i["count"] / total * 100, 1)}
192
+ for i in main
193
+ ]
194
+ if others:
195
+ result.append({"label": "others", "percentage": round(others / total * 100, 1)})
196
+ else: # column
197
+ result = [{"label": i["label"], "count": i["count"]} for i in main]
198
+ if others:
199
+ result.append({"label": "others", "count": others})
200
+
201
+ return result
tools/chat_tools.py CHANGED
@@ -7,12 +7,14 @@ Replaces the previous mock implementations.
7
  from . import memory as _memory_mod # noqa: F401
8
  from . import scheduler as _scheduler_mod # noqa: F401
9
  from . import summarizer as _summarizer_mod # noqa: F401
 
10
 
11
  from .base import TOOLS as _REGISTRY, get_langchain_tools
12
 
13
  _ALLOWED = {
14
  # Facilitator
15
  "summarize_chat",
 
16
  # Scheduler
17
  "get_schedule", "add_event", "update_event", "delete_event",
18
  "add_reminder", "get_reminders",
 
7
  from . import memory as _memory_mod # noqa: F401
8
  from . import scheduler as _scheduler_mod # noqa: F401
9
  from . import summarizer as _summarizer_mod # noqa: F401
10
+ from . import chart as _chart_mod # noqa: F401
11
 
12
  from .base import TOOLS as _REGISTRY, get_langchain_tools
13
 
14
  _ALLOWED = {
15
  # Facilitator
16
  "summarize_chat",
17
+ "summarize_chart",
18
  # Scheduler
19
  "get_schedule", "add_event", "update_event", "delete_event",
20
  "add_reminder", "get_reminders",