File size: 7,909 Bytes
bc1b4c2
 
 
 
 
 
 
ae05c68
bc1b4c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae05c68
 
 
bc1b4c2
 
 
ae05c68
bc1b4c2
 
ae05c68
bc1b4c2
ae05c68
bc1b4c2
 
ae05c68
bc1b4c2
ae05c68
bc1b4c2
 
 
 
 
 
 
 
 
 
 
ae05c68
 
 
bc1b4c2
 
 
 
 
 
 
 
 
 
ae05c68
bc1b4c2
 
ae05c68
 
 
 
bc1b4c2
 
ae05c68
bc1b4c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
try:
    __import__("pysqlite3")
    import sys
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
except ImportError:
    pass

import logging
import traceback
import pandas as pd
import docx2txt
import chromadb
from chromadb.config import Settings
import chainlit as cl # Thư viện giao diện mới

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# === CẤU HÌNH ===
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
DATA_PATH = "medical_data"
DB_PATH = "chroma_db"
MAX_HISTORY_TURNS = 4

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

# ... (GIỮ NGUYÊN CÁC HÀM: process_excel_file, load_documents_from_folder, get_retrievers) ...
# Để tiết kiệm chỗ hiển thị, tôi không paste lại đoạn code xử lý file ở đây
# Bạn hãy copy-paste lại các hàm process_excel_file, load_documents_from_folder, get_retrievers Y HỆT NHƯ CŨ vào đây.

# === LOGIC BOT (Sửa nhẹ để tách biệt khởi tạo) ===
class DeepMedBot:
    def __init__(self):
        self.fast_chain = None
        self.deep_chain = None
        self.ready = False
        
        if not GOOGLE_API_KEY:
            return

        try:
            self.fast_retriever, self.deep_retriever = get_retrievers()
            
            self.llm = ChatGoogleGenerativeAI(
                model="gemini-2.5-flash", 
                temperature=0.2,
                google_api_key=GOOGLE_API_KEY,
                convert_system_message_to_human=True
            )
            
            if self.fast_retriever and self.deep_retriever:
                self._build_chains()
                self.ready = True
            else:
                 self.ready = True 

        except Exception as e:
            logging.error(f"Lỗi khởi tạo bot: {e}")

    def _build_chains(self):
        context_system_prompt = (
            "Dựa trên lịch sử chat và câu hỏi mới nhất, hãy viết lại câu hỏi "
            "thành một câu hoàn chỉnh để tìm kiếm thông tin. "
            "CHỈ TRẢ VỀ CÂU HỎI ĐÃ VIẾT LẠI, KHÔNG TRẢ LỜI."
        )
        context_prompt = ChatPromptTemplate.from_messages([
            ("system", context_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        
        qa_system_prompt = (
            "Bạn là 'DeepMed-AI' - Trợ lý Dược lâm sàng chuyên nghiệp.\n"
            "Nhiệm vụ: Tư vấn điều trị CHỈ DỰA TRÊN Dữ liệu nội bộ (Context).\n"
            "QUY TẮC: Trung thực tuyệt đối, Kiểm tra thuốc có trong kho, Trích dẫn nguồn.\n"
            "Context:\n{context}"
        )
        qa_prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        
        question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
        
        history_aware_fast = create_history_aware_retriever(self.llm, self.fast_retriever, context_prompt)
        self.fast_chain = create_retrieval_chain(history_aware_fast, question_answer_chain)
        
        history_aware_deep = create_history_aware_retriever(self.llm, self.deep_retriever, context_prompt)
        self.deep_chain = create_retrieval_chain(history_aware_deep, question_answer_chain)

    def _build_references_text(self, docs) -> str:
        lines = []
        seen = set()
        for doc in docs:
            src = doc.metadata.get("source", "Tài liệu")
            row_info = f"(Dòng {doc.metadata['row']})" if "row" in doc.metadata else ""
            ref_str = f"- {src} {row_info}"
            if ref_str not in seen:
                lines.append(ref_str)
                seen.add(ref_str)
        return "\n".join(lines)

# === PHẦN GIAO DIỆN CHAINLIT ===

@cl.on_chat_start
async def start():
    # Khởi tạo bot khi phiên chat bắt đầu
    msg = cl.Message(content="Đang khởi động hệ thống DeepMed AI...")
    await msg.send()
    
    bot = DeepMedBot()
    cl.user_session.set("bot", bot) # Lưu bot vào session của người dùng
    
    # Gửi tin nhắn chào mừng và các nút chọn chế độ
    actions = [
        cl.Action(name="fast_mode", value="fast", label="⚡ Tốc độ (Nhanh)", description="Trả lời nhanh"),
        cl.Action(name="deep_mode", value="deep", label="🧠 Chuyên sâu (Kỹ)", description="Phân tích kỹ dữ liệu")
    ]
    cl.user_session.set("mode", "fast") # Mặc định là fast
    
    await msg.update(content="👋 Xin chào! Tôi là **DeepMed AI**. \n\nTôi có thể hỗ trợ tra cứu thuốc, phác đồ điều trị và thông tin dược lâm sàng. \n\n*Mặc định đang ở chế độ: Tốc độ.*", actions=actions)

@cl.action_callback("fast_mode")
async def on_fast_mode(action):
    cl.user_session.set("mode", "fast")
    await cl.Message(content="✅ Đã chuyển sang chế độ: **Tốc độ (Nhanh)**").send()

@cl.action_callback("deep_mode")
async def on_deep_mode(action):
    cl.user_session.set("mode", "deep")
    await cl.Message(content="✅ Đã chuyển sang chế độ: **Chuyên sâu (Chậm & Chính xác)**").send()

@cl.on_message
async def main(message: cl.Message):
    bot = cl.user_session.get("bot")
    mode_setting = cl.user_session.get("mode")
    
    if not bot or not bot.ready:
        await cl.Message(content="⚠️ Hệ thống chưa sẵn sàng hoặc gặp lỗi khởi tạo.").send()
        return

    # Lấy lịch sử chat từ Chainlit
    # Chainlit tự quản lý history, nhưng để tương thích code cũ, ta có thể lấy memory
    # Ở đây ta dùng bộ nhớ session đơn giản hoặc để LangChain tự lo
    
    # Xác định chain cần dùng
    active_chain = bot.deep_chain if mode_setting == "deep" else bot.fast_chain
    
    msg = cl.Message(content="")
    await msg.send()
    
    # Lấy lịch sử chat (đơn giản hóa cho demo)
    history = [] 
    
    full_response = ""
    retrieved_docs = []

    # Gọi Chain Async
    try:
        async for chunk in active_chain.astream({"input": message.content, "chat_history": history}):
            if "answer" in chunk:
                token = chunk["answer"]
                full_response += token
                await msg.stream_token(token)
            elif "context" in chunk:
                retrieved_docs = chunk["context"]
        
        # Hiển thị nguồn tham khảo đẹp mắt hơn
        if retrieved_docs:
            refs = bot._build_references_text(retrieved_docs)
            if refs:
                ref_block = f"\n\n---\n📚 **Nguồn tham khảo:**\n{refs}"
                await msg.stream_token(ref_block)
                
        await msg.update()
        
    except Exception as e:
        await cl.Message(content=f"⚠️ Lỗi: {str(e)}").send()