File size: 6,639 Bytes
d1b26b7
 
 
 
ba9644b
d7c1cf9
c3317d5
ba9644b
d1b26b7
 
ba9644b
 
c3317d5
 
ba9644b
d1b26b7
b9b60b0
d1b26b7
ba9644b
 
b784540
dc832e8
ba9644b
d1b26b7
 
 
ba9644b
 
 
 
 
 
9ec4919
dc832e8
 
4cafa1a
d7c1cf9
d1b26b7
d7c1cf9
d1b26b7
 
d7c1cf9
 
 
d1b26b7
 
 
 
 
ba9644b
 
d1b26b7
 
 
 
 
 
 
 
dc832e8
 
 
 
 
b784540
 
4cafa1a
c3317d5
 
 
9ec4919
c3317d5
 
 
 
 
 
 
 
 
 
ba9644b
 
 
 
 
 
 
 
 
 
 
 
 
 
b784540
ba9644b
 
 
 
 
 
f462894
d7c1cf9
ba9644b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b784540
ba9644b
 
 
 
 
 
f462894
d7c1cf9
ba9644b
d1b26b7
 
 
 
 
 
 
 
 
 
b784540
d1b26b7
 
 
ba9644b
d1b26b7
 
 
d7c1cf9
 
 
 
 
 
 
 
 
 
 
 
 
 
f462894
d7c1cf9
ba9644b
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""
Core agent orchestration — entry point dùng chung cho API và UI.
"""

import base64
import json
import logging
import mimetypes
import time
from datetime import datetime
from typing import Optional

logger = logging.getLogger(__name__)

from langchain_core.messages import HumanMessage, ToolMessage

from src.conversation_memory import add_turn
from src.graph import run
from src.nodes import final_response_node, image_response_node
from src.pdf_processing import format_chat_history, pdf_to_markdown
from src.qdrant_store import get_custom_prompt
from src.quiz import generate_quiz
from src.redis_client import redis_client
from src.state import MAX_ITERS, AgentState


def final_answer(
    conversation_id: str,
    sender_id: str,
    query: str,
    pdf_path: Optional[str] = None,
    image_path: Optional[str] = None,
    pdf_name: Optional[str] = None,
    gen_quiz: bool = False,
    k_question: Optional[int] = None,
    skip_pdf_indexing: bool = False,
) -> tuple[str, str, str | None, str | None]:
    """
    Khởi tạo AgentState, chạy graph, trả về 4 giá trị.

    Returns:
        (answer, elapsed, chart_type, chart_data)
        - chart_type : "column" | "pie" | None
        - chart_data : JSON string | None (chỉ có khi tool summarize_chart được gọi)

    Raises:
        ValueError: nếu bất kỳ tham số bắt buộc nào rỗng.
    """
    conversation_id = conversation_id.strip()
    sender_id       = sender_id.strip()
    query           = query.strip()

    if not conversation_id:
        raise ValueError("conversation_id không được để trống.")
    if not sender_id:
        raise ValueError("sender_id không được để trống.")
    if not query:
        raise ValueError("query không được để trống.")

    if gen_quiz:
        t0 = time.perf_counter()
        answer = generate_quiz(query, k_question or 10)
        return answer, f"{time.perf_counter() - t0:.2f}s", None, None

    custom_prompt = get_custom_prompt(sender_id)

    if pdf_path is not None and not skip_pdf_indexing:
        # Auto-index vào Qdrant để dùng với rag_search sau này (idempotent)
        try:
            from src.pdf_rag import index_pdf
            index_pdf(pdf_path, pdf_name or "document.pdf", conversation_id)
        except Exception:
            logger.exception("Auto-index PDF thất bại.")

        pdf_content  = pdf_to_markdown(pdf_path)
        chat_history = redis_client.get_chat_history(conversation_id)
        chat_text    = format_chat_history(chat_history)
        tool_content = (
            f"[Nội dung PDF]\n{pdf_content}"
            f"\n\n[Lịch sử trò chuyện]\n{chat_text}"
        )

        state: AgentState = {
            "conversation_id": conversation_id,
            "sender_id":       sender_id,
            "time":            datetime.now().isoformat(),
            "raw_query":       query,
            "query_type":      None,
            "messages":        [
                HumanMessage(content=query),
                ToolMessage(content=tool_content, tool_call_id="pdf_reader", name="pdf_reader"),
            ],
            "iters":           0,
            "max_iters":       MAX_ITERS,
            "final_answer":    None,
            "custom_prompt":   custom_prompt,
        }

        t0 = time.perf_counter()
        result  = final_response_node(state)
        elapsed = f"{time.perf_counter() - t0:.2f}s"
        answer  = result.get("final_answer") or "(Không có kết quả)"
        add_turn(conversation_id, sender_id, query, answer)
        return answer, elapsed, None, None

    if image_path is not None:
        mime_type, _ = mimetypes.guess_type(image_path)
        mime_type = mime_type or "image/jpeg"

        with open(image_path, "rb") as f:
            image_b64 = base64.b64encode(f.read()).decode()

        chat_history = redis_client.get_chat_history(conversation_id)
        chat_text    = format_chat_history(chat_history)

        text_content = f"{query}\n\n[Lịch sử trò chuyện]\n{chat_text}"

        state: AgentState = {
            "conversation_id": conversation_id,
            "sender_id":       sender_id,
            "time":            datetime.now().isoformat(),
            "raw_query":       query,
            "query_type":      None,
            "messages":        [
                HumanMessage(content=[
                    {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                    {"type": "text", "text": text_content},
                ]),
            ],
            "iters":           0,
            "max_iters":       MAX_ITERS,
            "final_answer":    None,
            "custom_prompt":   custom_prompt,
        }

        t0 = time.perf_counter()
        result  = image_response_node(state)
        elapsed = f"{time.perf_counter() - t0:.2f}s"
        answer  = result.get("final_answer") or "(Không có kết quả)"
        add_turn(conversation_id, sender_id, query, answer)
        return answer, elapsed, None, None

    initial_state: AgentState = {
        "conversation_id": conversation_id,
        "sender_id":       sender_id,
        "time":            datetime.now().isoformat(),
        "raw_query":       query,
        "query_type":      None,
        "messages":        [],
        "iters":           0,
        "max_iters":       MAX_ITERS,
        "final_answer":    None,
        "custom_prompt":   custom_prompt,
    }

    t0 = time.perf_counter()
    result  = run(initial_state)
    elapsed = f"{time.perf_counter() - t0:.2f}s"

    answer = result.get("final_answer") or "(Không có kết quả)"

    chart_type: str | None = None
    chart_data: str | None = None
    for msg in result.get("messages", []):
        if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "summarize_chart":
            try:
                parsed = json.loads(msg.content)
                if parsed.get("status") == "success":
                    chart_type = parsed.get("chart_type")
                    chart_data = json.dumps(parsed.get("chart_data", []), ensure_ascii=False)
            except Exception:
                pass
            break

    add_turn(conversation_id, sender_id, query, answer)
    return answer, elapsed, chart_type, chart_data


if __name__ == "__main__":
    # Manual smoke test: summarise a local PDF and print the answer and timing.
    # final_answer returns (answer, elapsed, chart_type, chart_data); all four
    # values must be unpacked even though the chart fields are unused here.
    answer, elapsed, _chart_type, _chart_data = final_answer(
        conversation_id="04ba40fe-61c7-4906-9f51-5ada0a392dac",
        sender_id="@slavakpa",
        query="tóm tắt nội dung tài liệu này",
        pdf_path="temp/test_doc.pdf",
    )
    print(answer)
    print(f"\n({elapsed})")