anhkhoiphan commited on
Commit
ba9644b
·
1 Parent(s): 20a314b

Hoàn thiện luồng xử lý pdf và ảnh

Browse files
Files changed (3) hide show
  1. core.py +95 -4
  2. llm.py +6 -3
  3. nodes.py +18 -1
core.py CHANGED
@@ -2,14 +2,28 @@
2
  Core agent orchestration — entry point dùng chung cho API và UI.
3
  """
4
 
 
 
5
  import time
6
  from datetime import datetime
 
 
 
7
 
8
  from src.graph import run
 
 
 
9
  from src.state import MAX_ITERS, AgentState
10
 
11
 
12
- def final_answer(conversation_id: str, sender_id: str, query: str) -> tuple[str, str]:
 
 
 
 
 
 
13
  """
14
  Khởi tạo AgentState, chạy graph, trả về (câu trả lời, thời gian xử lý).
15
 
@@ -20,8 +34,8 @@ def final_answer(conversation_id: str, sender_id: str, query: str) -> tuple[str,
20
  ValueError: nếu bất kỳ tham số bắt buộc nào rỗng.
21
  """
22
  conversation_id = conversation_id.strip()
23
- sender_id = sender_id.strip()
24
- query = query.strip()
25
 
26
  if not conversation_id:
27
  raise ValueError("conversation_id không được để trống.")
@@ -30,6 +44,72 @@ def final_answer(conversation_id: str, sender_id: str, query: str) -> tuple[str,
30
  if not query:
31
  raise ValueError("query không được để trống.")
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  initial_state: AgentState = {
34
  "conversation_id": conversation_id,
35
  "sender_id": sender_id,
@@ -43,8 +123,19 @@ def final_answer(conversation_id: str, sender_id: str, query: str) -> tuple[str,
43
  }
44
 
45
  t0 = time.perf_counter()
46
- result = run(initial_state)
47
  elapsed = f"{time.perf_counter() - t0:.2f}s"
48
 
49
  answer = result.get("final_answer") or "(Không có kết quả)"
50
  return answer, elapsed
 
 
 
 
 
 
 
 
 
 
 
 
2
  Core agent orchestration — entry point dùng chung cho API và UI.
3
  """
4
 
5
+ import base64
6
+ import mimetypes
7
  import time
8
  from datetime import datetime
9
+ from typing import Optional
10
+
11
+ from langchain_core.messages import HumanMessage, ToolMessage
12
 
13
  from src.graph import run
14
+ from src.nodes import final_response_node, image_response_node
15
+ from src.pdf_processing import format_chat_history, pdf_to_markdown
16
+ from src.redis_client import redis_client
17
  from src.state import MAX_ITERS, AgentState
18
 
19
 
20
+ def final_answer(
21
+ conversation_id: str,
22
+ sender_id: str,
23
+ query: str,
24
+ pdf_path: Optional[str] = None,
25
+ image_path: Optional[str] = None,
26
+ ) -> tuple[str, str]:
27
  """
28
  Khởi tạo AgentState, chạy graph, trả về (câu trả lời, thời gian xử lý).
29
 
 
34
  ValueError: nếu bất kỳ tham số bắt buộc nào rỗng.
35
  """
36
  conversation_id = conversation_id.strip()
37
+ sender_id = sender_id.strip()
38
+ query = query.strip()
39
 
40
  if not conversation_id:
41
  raise ValueError("conversation_id không được để trống.")
 
44
  if not query:
45
  raise ValueError("query không được để trống.")
46
 
47
+ if pdf_path is not None:
48
+ pdf_content = pdf_to_markdown(pdf_path)
49
+ chat_history = redis_client.get_chat_history(conversation_id)
50
+ chat_text = format_chat_history(chat_history)
51
+
52
+ tool_content = (
53
+ f"[Nội dung PDF]\n{pdf_content}"
54
+ f"\n\n[Lịch sử trò chuyện]\n{chat_text}"
55
+ )
56
+
57
+ state: AgentState = {
58
+ "conversation_id": conversation_id,
59
+ "sender_id": sender_id,
60
+ "time": datetime.now().isoformat(),
61
+ "raw_query": query,
62
+ "query_type": None,
63
+ "messages": [
64
+ HumanMessage(content=query),
65
+ ToolMessage(content=tool_content, tool_call_id="pdf_reader", name="pdf_reader"),
66
+ ],
67
+ "iters": 0,
68
+ "max_iters": MAX_ITERS,
69
+ "final_answer": None,
70
+ }
71
+
72
+ t0 = time.perf_counter()
73
+ result = final_response_node(state)
74
+ elapsed = f"{time.perf_counter() - t0:.2f}s"
75
+ answer = result.get("final_answer") or "(Không có kết quả)"
76
+ return answer, elapsed
77
+
78
+ if image_path is not None:
79
+ mime_type, _ = mimetypes.guess_type(image_path)
80
+ mime_type = mime_type or "image/jpeg"
81
+
82
+ with open(image_path, "rb") as f:
83
+ image_b64 = base64.b64encode(f.read()).decode()
84
+
85
+ chat_history = redis_client.get_chat_history(conversation_id)
86
+ chat_text = format_chat_history(chat_history)
87
+
88
+ text_content = f"{query}\n\n[Lịch sử trò chuyện]\n{chat_text}"
89
+
90
+ state: AgentState = {
91
+ "conversation_id": conversation_id,
92
+ "sender_id": sender_id,
93
+ "time": datetime.now().isoformat(),
94
+ "raw_query": query,
95
+ "query_type": None,
96
+ "messages": [
97
+ HumanMessage(content=[
98
+ {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
99
+ {"type": "text", "text": text_content},
100
+ ]),
101
+ ],
102
+ "iters": 0,
103
+ "max_iters": MAX_ITERS,
104
+ "final_answer": None,
105
+ }
106
+
107
+ t0 = time.perf_counter()
108
+ result = image_response_node(state)
109
+ elapsed = f"{time.perf_counter() - t0:.2f}s"
110
+ answer = result.get("final_answer") or "(Không có kết quả)"
111
+ return answer, elapsed
112
+
113
  initial_state: AgentState = {
114
  "conversation_id": conversation_id,
115
  "sender_id": sender_id,
 
123
  }
124
 
125
  t0 = time.perf_counter()
126
+ result = run(initial_state)
127
  elapsed = f"{time.perf_counter() - t0:.2f}s"
128
 
129
  answer = result.get("final_answer") or "(Không có kết quả)"
130
  return answer, elapsed
131
+
132
+
133
+ if __name__ == "__main__":
134
+ answer, elapsed = final_answer(
135
+ conversation_id="04ba40fe-61c7-4906-9f51-5ada0a392dac",
136
+ sender_id="@slavakpa",
137
+ query="tóm tắt nội dung tài liệu này",
138
+ pdf_path="temp/test_doc.pdf",
139
+ )
140
+ print(answer)
141
+ print(f"\n({elapsed})")
llm.py CHANGED
@@ -1,17 +1,20 @@
1
  from langchain_google_genai import ChatGoogleGenerativeAI
2
  from src.config import GEMINI_API_KEY, DEFAULT_MODEL
3
 
4
- llm = ChatGoogleGenerativeAI(
5
- model=DEFAULT_MODEL,
6
  temperature=0,
7
  top_p=1,
8
  top_k=1,
9
  max_tokens=None,
10
  timeout=None,
11
  max_retries=2,
12
- google_api_key=GEMINI_API_KEY
13
  )
14
 
 
 
 
 
15
  if __name__ == "__main__":
16
  response = llm.invoke("Hello World là gì?").content
17
  print(response)
 
1
  from langchain_google_genai import ChatGoogleGenerativeAI
2
  from src.config import GEMINI_API_KEY, DEFAULT_MODEL
3
 
4
+ _base_kwargs = dict(
 
5
  temperature=0,
6
  top_p=1,
7
  top_k=1,
8
  max_tokens=None,
9
  timeout=None,
10
  max_retries=2,
11
+ google_api_key=GEMINI_API_KEY,
12
  )
13
 
14
+ llm = ChatGoogleGenerativeAI(model=DEFAULT_MODEL, **_base_kwargs)
15
+
16
+ multimodal_llm = ChatGoogleGenerativeAI(model=DEFAULT_MODEL, **_base_kwargs)
17
+
18
  if __name__ == "__main__":
19
  response = llm.invoke("Hello World là gì?").content
20
  print(response)
nodes.py CHANGED
@@ -6,7 +6,7 @@ from typing import Any, Literal
6
 
7
  from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
8
 
9
- from src.llm import llm
10
  from src.prompts import (
11
  final_response_prompt,
12
  orchestrator_prompt,
@@ -162,6 +162,23 @@ def _extract_tool_results(state: AgentState) -> str:
162
  return "\n\n".join(parts) if parts else "(Không có kết quả từ tool)"
163
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  def final_response_node(state: AgentState) -> AgentState:
166
  """Tổng hợp ToolMessage(s) và sinh câu trả lời cuối cùng."""
167
  logger.info("[FinalResponseNode] Tổng hợp câu trả lời...")
 
6
 
7
  from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
8
 
9
+ from src.llm import llm, multimodal_llm
10
  from src.prompts import (
11
  final_response_prompt,
12
  orchestrator_prompt,
 
162
  return "\n\n".join(parts) if parts else "(Không có kết quả từ tool)"
163
 
164
 
165
+ # ════════════════════════════════════════════════════════════════════
166
+ # NODE 6 — ImageResponseNode
167
+ # ════════════════════════════════════════════════════════════════════
168
+ def image_response_node(state: AgentState) -> AgentState:
169
+ """Nhận HumanMessage chứa ảnh + text, gọi multimodal LLM sinh câu trả lời."""
170
+ logger.info("[ImageResponseNode] Xử lý ảnh cho %s", state["sender_id"])
171
+
172
+ response = multimodal_llm.invoke(state["messages"])
173
+ answer = response.content
174
+
175
+ logger.info("[ImageResponseNode] Hoàn thành (%d ký tự)", len(answer))
176
+ return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}
177
+
178
+
179
+ # ════════════════════════════════════════════════════════════════════
180
+ # NODE 5 — FinalResponseNode
181
+ # ════════════════════════════════════════════════════════════════════
182
  def final_response_node(state: AgentState) -> AgentState:
183
  """Tổng hợp ToolMessage(s) và sinh câu trả lời cuối cùng."""
184
  logger.info("[FinalResponseNode] Tổng hợp câu trả lời...")